Unverified Commit 36d7f198 authored by Li, Jiang's avatar Li, Jiang Committed by GitHub
Browse files

[CPU] Support head_size 512 in cpu_attn (#38676)


Signed-off-by: default avatarjiang1.li <jiang1.li@intel.com>
parent 2d725b89
......@@ -8,7 +8,7 @@ Generate CPU attention dispatch switch cases and kernel instantiations.
import os
# Head dimensions divisible by 32 (support all ISAs)
HEAD_DIMS_32 = [32, 64, 96, 128, 160, 192, 224, 256]
HEAD_DIMS_32 = [32, 64, 96, 128, 160, 192, 224, 256, 512]
# Head dimensions divisible by 16 but not 32 (VEC16 only)
HEAD_DIMS_16 = [80, 112]
......
......@@ -165,7 +165,7 @@ Priority is **1 = highest** (tried first).
| Backend | Version | Dtypes | KV Dtypes | Block Sizes | Head Sizes | Sink | MM Prefix | DCP | Attention Types | Compute Cap. |
| ------- | ------- | ------ | --------- | ----------- | ---------- | ---- | --------- | --- | --------------- | ------------ |
| `CPU_ATTN` | | fp16, bf16, fp32 | `auto` | Any | 32, 64, 80, 96, 112, 128, 160, 192, 224, 256 | ❌ | ❌ | ❌ | All | N/A |
| `CPU_ATTN` | | fp16, bf16, fp32 | `auto` | Any | 32, 64, 80, 96, 112, 128, 160, 192, 224, 256, 512 | ❌ | ❌ | ❌ | All | N/A |
| `FLASHINFER` | Native† | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32, 64 | 64, 128, 256 | ❌ | ❌ | ✅ | Decoder | 7.x-9.x |
| `FLASHINFER` | TRTLLM† | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32, 64 | 64, 128, 256 | ✅ | ❌ | ✅ | Decoder | 10.x |
| `FLASH_ATTN` | FA2* | fp16, bf16 | `auto`, `float16`, `bfloat16` | %16 | Any | ❌ | ❌ | ✅ | All | ≥8.0 |
......
......@@ -25,7 +25,7 @@ NUM_HEADS = [
(8, 2),
(9, 3),
]
HEAD_SIZES = [96, 128]
HEAD_SIZES = [96, 128, 512]
HEAD_SIZES_VEC16 = [96, 80, 112, 128]
QTYPES = [torch.bfloat16, torch.half, torch.float32]
SLIDING_WINDOWS = [None, 256]
......
......@@ -38,7 +38,7 @@ class CPUAttentionBackend(AttentionBackend):
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return [32, 64, 80, 96, 112, 128, 160, 192, 224, 256]
return [32, 64, 80, 96, 112, 128, 160, 192, 224, 256, 512]
@staticmethod
def get_name() -> str:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment