Unverified commit 8e27663b, authored by R3hankhan and committed by GitHub
Browse files

[CPU] Add head sizes 80 and 112 with vec16 fallback (#31968)


Signed-off-by: Rehan Khan <Rehan.Khan7@ibm.com>
parent 7cdf7e2f
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#ifdef __aarch64__ #ifdef __aarch64__
#include "cpu_attn_neon.hpp" #include "cpu_attn_neon.hpp"
// NEON requires head_dim to be a multiple of 32
#define NEON_DISPATCH(...) \ #define NEON_DISPATCH(...) \
case cpu_attention::ISA::NEON: { \ case cpu_attention::ISA::NEON: { \
using attn_impl = cpu_attention::AttentionImpl<cpu_attention::ISA::NEON, \ using attn_impl = cpu_attention::AttentionImpl<cpu_attention::ISA::NEON, \
...@@ -36,7 +37,9 @@ ...@@ -36,7 +37,9 @@
switch (HEAD_DIM) { \ switch (HEAD_DIM) { \
CPU_ATTN_DISPATCH_CASE(32, __VA_ARGS__) \ CPU_ATTN_DISPATCH_CASE(32, __VA_ARGS__) \
CPU_ATTN_DISPATCH_CASE(64, __VA_ARGS__) \ CPU_ATTN_DISPATCH_CASE(64, __VA_ARGS__) \
CPU_ATTN_DISPATCH_CASE(80, __VA_ARGS__) \
CPU_ATTN_DISPATCH_CASE(96, __VA_ARGS__) \ CPU_ATTN_DISPATCH_CASE(96, __VA_ARGS__) \
CPU_ATTN_DISPATCH_CASE(112, __VA_ARGS__) \
CPU_ATTN_DISPATCH_CASE(128, __VA_ARGS__) \ CPU_ATTN_DISPATCH_CASE(128, __VA_ARGS__) \
CPU_ATTN_DISPATCH_CASE(160, __VA_ARGS__) \ CPU_ATTN_DISPATCH_CASE(160, __VA_ARGS__) \
CPU_ATTN_DISPATCH_CASE(192, __VA_ARGS__) \ CPU_ATTN_DISPATCH_CASE(192, __VA_ARGS__) \
......
...@@ -377,7 +377,7 @@ class AttentionImpl<ISA::AMX, scalar_t, head_dim> { ...@@ -377,7 +377,7 @@ class AttentionImpl<ISA::AMX, scalar_t, head_dim> {
const int32_t q_heads_per_kv, const int64_t q_num_stride, const int32_t q_heads_per_kv, const int64_t q_num_stride,
const int64_t q_head_stride, const float scale) { const int64_t q_head_stride, const float scale) {
constexpr int64_t bytes_per_head = head_dim * sizeof(scalar_t); constexpr int64_t bytes_per_head = head_dim * sizeof(scalar_t);
static_assert(bytes_per_head % AMX_TILE_ROW_BYTES == 0); // static_assert(bytes_per_head % AMX_TILE_ROW_BYTES == 0);
constexpr int64_t head_size_block_num = bytes_per_head / AMX_TILE_ROW_BYTES; constexpr int64_t head_size_block_num = bytes_per_head / AMX_TILE_ROW_BYTES;
constexpr int64_t head_elem_num_pre_block = constexpr int64_t head_elem_num_pre_block =
AMX_TILE_ROW_BYTES / sizeof(scalar_t); AMX_TILE_ROW_BYTES / sizeof(scalar_t);
......
...@@ -264,7 +264,7 @@ class AttentionImpl<ISA::NEON, scalar_t, head_dim> { ...@@ -264,7 +264,7 @@ class AttentionImpl<ISA::NEON, scalar_t, head_dim> {
constexpr static ISA ISAType = ISA::NEON; constexpr static ISA ISAType = ISA::NEON;
constexpr static bool scale_on_logits = false; // apply scale on q_buffer constexpr static bool scale_on_logits = false; // apply scale on q_buffer
static_assert(HeadDim % HeadDimAlignment == 0); // static_assert(HeadDim % HeadDimAlignment == 0);
// the gemm micro kernel is Mx8 // the gemm micro kernel is Mx8
static_assert(HeadDimAlignment % 8 == 0); static_assert(HeadDimAlignment % 8 == 0);
static_assert(BlockSizeAlignment % 8 == 0); static_assert(BlockSizeAlignment % 8 == 0);
......
...@@ -42,7 +42,7 @@ class CPUAttentionBackend(AttentionBackend): ...@@ -42,7 +42,7 @@ class CPUAttentionBackend(AttentionBackend):
@classmethod @classmethod
def get_supported_head_sizes(cls) -> list[int]: def get_supported_head_sizes(cls) -> list[int]:
return [32, 64, 96, 128, 160, 192, 224, 256] return [32, 64, 80, 96, 112, 128, 160, 192, 224, 256]
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
...@@ -137,7 +137,7 @@ class CPUAttentionMetadataBuilder(AttentionMetadataBuilder[CPUAttentionMetadata] ...@@ -137,7 +137,7 @@ class CPUAttentionMetadataBuilder(AttentionMetadataBuilder[CPUAttentionMetadata]
if self.window_size is None: if self.window_size is None:
self.window_size = -1 self.window_size = -1
self.block_size = vllm_config.cache_config.block_size self.block_size = vllm_config.cache_config.block_size
self.isa = _get_attn_isa(self.dtype, self.block_size) self.isa = _get_attn_isa(self.dtype, self.block_size, self.head_dim)
self.is_cross_attention = isinstance(kv_cache_spec, CrossAttentionSpec) self.is_cross_attention = isinstance(kv_cache_spec, CrossAttentionSpec)
def build( def build(
...@@ -484,7 +484,11 @@ def _make_sliding_window_bias( ...@@ -484,7 +484,11 @@ def _make_sliding_window_bias(
return attn_biases return attn_biases
def _get_attn_isa(dtype: torch.dtype, block_size: int) -> str: def _get_attn_isa(
dtype: torch.dtype, block_size: int, head_size: int | None = None
) -> str:
if head_size is not None and head_size % 32 != 0 and head_size % 16 == 0:
return "vec16"
supports_amx = torch._C._cpu._is_amx_tile_supported() supports_amx = torch._C._cpu._is_amx_tile_supported()
if supports_amx and dtype in (torch.bfloat16,) and block_size % 32 == 0: if supports_amx and dtype in (torch.bfloat16,) and block_size % 32 == 0:
return "amx" return "amx"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment