Unverified Commit 93f3c8e5 authored by Matthew Bonanni's avatar Matthew Bonanni Committed by GitHub
Browse files

[Misc] Add `float16` to `CacheDType` (#37199)


Signed-off-by: default avatarMatthew Bonanni <mbonanni@redhat.com>
parent 2cc26c3a
...@@ -164,18 +164,18 @@ Priority is **1 = highest** (tried first). ...@@ -164,18 +164,18 @@ Priority is **1 = highest** (tried first).
| Backend | Version | Dtypes | KV Dtypes | Block Sizes | Head Sizes | Sink | MM Prefix | DCP | Attention Types | Compute Cap. | | Backend | Version | Dtypes | KV Dtypes | Block Sizes | Head Sizes | Sink | MM Prefix | DCP | Attention Types | Compute Cap. |
| ------- | ------- | ------ | --------- | ----------- | ---------- | ---- | --------- | --- | --------------- | ------------ | | ------- | ------- | ------ | --------- | ----------- | ---------- | ---- | --------- | --- | --------------- | ------------ |
| `CPU_ATTN` | | fp16, bf16, fp32 | `auto` | Any | 32, 64, 80, 96, 112, 128, 160, 192, 224, 256 | ❌ | ❌ | ❌ | All | N/A | | `CPU_ATTN` | | fp16, bf16, fp32 | `auto` | Any | 32, 64, 80, 96, 112, 128, 160, 192, 224, 256 | ❌ | ❌ | ❌ | All | N/A |
| `FLASHINFER` | Native† | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32, 64 | 64, 128, 256 | ❌ | ❌ | ✅ | Decoder | 7.x-9.x | | `FLASHINFER` | Native† | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32, 64 | 64, 128, 256 | ❌ | ❌ | ✅ | Decoder | 7.x-9.x |
| `FLASHINFER` | TRTLLM† | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32, 64 | 64, 128, 256 | ✅ | ❌ | ✅ | Decoder | 10.x | | `FLASHINFER` | TRTLLM† | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32, 64 | 64, 128, 256 | ✅ | ❌ | ✅ | Decoder | 10.x |
| `FLASH_ATTN` | FA2* | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ✅ | All | ≥8.0 | | `FLASH_ATTN` | FA2* | fp16, bf16 | `auto`, `float16`, `bfloat16` | %16 | Any | ❌ | ❌ | ✅ | All | ≥8.0 |
| `FLASH_ATTN` | FA3* | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | Any | ✅ | ❌ | ✅ | All | 9.x | | `FLASH_ATTN` | FA3* | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | Any | ✅ | ❌ | ✅ | All | 9.x |
| `FLASH_ATTN` | FA4* | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ✅ | All | ≥10.0 | | `FLASH_ATTN` | FA4* | fp16, bf16 | `auto`, `float16`, `bfloat16` | %16 | Any | ❌ | ❌ | ✅ | All | ≥10.0 |
| `FLASH_ATTN_DIFFKV` | | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ✅ | Decoder | Any | | `FLASH_ATTN_DIFFKV` | | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ✅ | Decoder | Any |
| `FLEX_ATTENTION` | | fp16, bf16, fp32 | `auto`, `bfloat16` | Any | Any | ❌ | ✅ | ❌ | Decoder, Encoder Only | Any | | `FLEX_ATTENTION` | | fp16, bf16, fp32 | `auto`, `float16`, `bfloat16` | Any | Any | ❌ | ✅ | ❌ | Decoder, Encoder Only | Any |
| `ROCM_AITER_FA` | | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32 | 64, 128, 256 | ❌ | ❌ | ❌ | Decoder, Enc-Dec | N/A | | `ROCM_AITER_FA` | | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32 | 64, 128, 256 | ❌ | ❌ | ❌ | Decoder, Enc-Dec | N/A |
| `ROCM_AITER_UNIFIED_ATTN` | | fp16, bf16 | `auto` | %16 | Any | ✅ | ✅ | ❌ | All | N/A | | `ROCM_AITER_UNIFIED_ATTN` | | fp16, bf16 | `auto` | %16 | Any | ✅ | ✅ | ❌ | All | N/A |
| `ROCM_ATTN` | | fp16, bf16, fp32 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | 32, 64, 80, 96, 128, 160, 192, 224, 256 | ✅ | ✅ | ❌ | All | N/A | | `ROCM_ATTN` | | fp16, bf16, fp32 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | 32, 64, 80, 96, 128, 160, 192, 224, 256 | ✅ | ✅ | ❌ | All | N/A |
| `TREE_ATTN` | | fp16, bf16 | `auto` | %16 | 32, 64, 96, 128, 160, 192, 224, 256 | ❌ | ❌ | ❌ | Decoder | Any | | `TREE_ATTN` | | fp16, bf16 | `auto`, `float16`, `bfloat16` | %16 | 32, 64, 96, 128, 160, 192, 224, 256 | ❌ | ❌ | ❌ | Decoder | Any |
| `TRITON_ATTN` | | fp16, bf16, fp32 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | Any | ✅ | ✅ | ❌ | All | Any | | `TRITON_ATTN` | | fp16, bf16, fp32 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | Any | ✅ | ✅ | ❌ | All | Any |
> **†** FlashInfer uses TRTLLM attention on Blackwell (SM100), which supports sinks. Disable via `--attention-config.use_trtllm_attention=0`. > **†** FlashInfer uses TRTLLM attention on Blackwell (SM100), which supports sinks. Disable via `--attention-config.use_trtllm_attention=0`.
> >
...@@ -204,14 +204,14 @@ configuration. ...@@ -204,14 +204,14 @@ configuration.
| Backend | Dtypes | KV Dtypes | Block Sizes | Head Sizes | Sink | Sparse | MM Prefix | DCP | Attention Types | Compute Cap. | | Backend | Dtypes | KV Dtypes | Block Sizes | Head Sizes | Sink | Sparse | MM Prefix | DCP | Attention Types | Compute Cap. |
| ------- | ------ | --------- | ----------- | ---------- | ---- | ------ | --------- | --- | --------------- | ------------ | | ------- | ------ | --------- | ----------- | ---------- | ---- | ------ | --------- | --- | --------------- | ------------ |
| `CUTLASS_MLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3` | 128 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 10.x | | `CUTLASS_MLA` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3` | 128 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 10.x |
| `FLASHINFER_MLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3` | 32, 64 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | 10.x | | `FLASHINFER_MLA` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3` | 32, 64 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | 10.x |
| `FLASHINFER_MLA_SPARSE` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3` | 32, 64 | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 10.x | | `FLASHINFER_MLA_SPARSE` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3` | 32, 64 | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 10.x |
| `FLASHMLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3` | 64 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x-10.x | | `FLASHMLA` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3` | 64 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x-10.x |
| `FLASHMLA_SPARSE` | bf16 | `auto`, `bfloat16`, `fp8_ds_mla` | 64 | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 9.x-10.x | | `FLASHMLA_SPARSE` | bf16 | `auto`, `bfloat16`, `fp8_ds_mla` | 64 | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 9.x-10.x |
| `FLASH_ATTN_MLA` | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x | | `FLASH_ATTN_MLA` | fp16, bf16 | `auto`, `float16`, `bfloat16` | %16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x |
| `ROCM_AITER_MLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 1 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A | | `ROCM_AITER_MLA` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 1 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
| `ROCM_AITER_MLA_SPARSE` | fp16, bf16 | `auto`, `bfloat16` | 1 | Any | ❌ | ✅ | ❌ | ❌ | Decoder | N/A | | `ROCM_AITER_MLA_SPARSE` | fp16, bf16 | `auto`, `float16`, `bfloat16` | 1 | Any | ❌ | ✅ | ❌ | ❌ | Decoder | N/A |
| `ROCM_AITER_TRITON_MLA` | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A | | `ROCM_AITER_TRITON_MLA` | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
| `TRITON_MLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3` | %16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | Any | | `TRITON_MLA` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3` | %16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | Any |
| `XPU_MLA_SPARSE` | fp16, bf16 | `auto`, `bfloat16` | Any | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | Any | | `XPU_MLA_SPARSE` | fp16, bf16 | `auto`, `float16`, `bfloat16` | Any | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | Any |
...@@ -13,6 +13,7 @@ logger = init_logger(__name__) ...@@ -13,6 +13,7 @@ logger = init_logger(__name__)
CacheDType = Literal[ CacheDType = Literal[
"auto", "auto",
"float16",
"bfloat16", "bfloat16",
"fp8", "fp8",
"fp8_e4m3", "fp8_e4m3",
......
...@@ -51,7 +51,11 @@ class AttentionBackend(ABC): ...@@ -51,7 +51,11 @@ class AttentionBackend(ABC):
# makes sure the output tensor is allocated inside the cudagraph. # makes sure the output tensor is allocated inside the cudagraph.
accept_output_buffer: bool = False accept_output_buffer: bool = False
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kv_cache_dtypes: ClassVar[list["CacheDType"]] = ["auto", "bfloat16"] supported_kv_cache_dtypes: ClassVar[list["CacheDType"]] = [
"auto",
"float16",
"bfloat16",
]
# Does attention's forward() include kv cache update? # Does attention's forward() include kv cache update?
forward_includes_kv_cache_update: bool = True forward_includes_kv_cache_update: bool = True
......
...@@ -64,6 +64,11 @@ logger = init_logger(__name__) ...@@ -64,6 +64,11 @@ logger = init_logger(__name__)
class FlashAttentionBackend(AttentionBackend): class FlashAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
"float16",
"bfloat16",
]
@staticmethod @staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]: def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
...@@ -164,7 +169,7 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -164,7 +169,7 @@ class FlashAttentionBackend(AttentionBackend):
return True return True
if kv_cache_dtype.startswith("fp8"): if kv_cache_dtype.startswith("fp8"):
return flash_attn_supports_fp8() return flash_attn_supports_fp8()
return kv_cache_dtype in ["auto", "bfloat16"] return kv_cache_dtype in ["auto", "float16", "bfloat16"]
@classmethod @classmethod
def supports_sink(cls) -> bool: def supports_sink(cls) -> bool:
......
...@@ -291,6 +291,7 @@ class FlashInferBackend(AttentionBackend): ...@@ -291,6 +291,7 @@ class FlashInferBackend(AttentionBackend):
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto", "auto",
"float16",
"bfloat16", "bfloat16",
"fp8", "fp8",
"fp8_e4m3", "fp8_e4m3",
......
...@@ -80,7 +80,11 @@ class FlexAttentionBackend(AttentionBackend): ...@@ -80,7 +80,11 @@ class FlexAttentionBackend(AttentionBackend):
torch.bfloat16, torch.bfloat16,
torch.float32, torch.float32,
] ]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto", "bfloat16"] supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
"float16",
"bfloat16",
]
forward_includes_kv_cache_update: bool = False forward_includes_kv_cache_update: bool = False
......
...@@ -39,6 +39,7 @@ class CutlassMLABackend(MLACommonBackend): ...@@ -39,6 +39,7 @@ class CutlassMLABackend(MLACommonBackend):
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto", "auto",
"float16",
"bfloat16", "bfloat16",
"fp8", "fp8",
"fp8_e4m3", "fp8_e4m3",
......
...@@ -46,6 +46,7 @@ class FlashAttnMLABackend(MLACommonBackend): ...@@ -46,6 +46,7 @@ class FlashAttnMLABackend(MLACommonBackend):
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto", "auto",
"float16",
"bfloat16", "bfloat16",
] ]
......
...@@ -38,6 +38,7 @@ class FlashInferMLABackend(MLACommonBackend): ...@@ -38,6 +38,7 @@ class FlashInferMLABackend(MLACommonBackend):
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto", "auto",
"float16",
"bfloat16", "bfloat16",
"fp8", "fp8",
"fp8_e4m3", "fp8_e4m3",
......
...@@ -62,6 +62,7 @@ class FlashInferMLASparseBackend(AttentionBackend): ...@@ -62,6 +62,7 @@ class FlashInferMLASparseBackend(AttentionBackend):
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto", "auto",
"float16",
"bfloat16", "bfloat16",
"fp8", "fp8",
"fp8_e4m3", "fp8_e4m3",
......
...@@ -49,6 +49,7 @@ class FlashMLABackend(MLACommonBackend): ...@@ -49,6 +49,7 @@ class FlashMLABackend(MLACommonBackend):
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto", "auto",
"float16",
"bfloat16", "bfloat16",
"fp8", "fp8",
"fp8_e4m3", "fp8_e4m3",
......
...@@ -26,6 +26,7 @@ class AiterMLABackend(MLACommonBackend): ...@@ -26,6 +26,7 @@ class AiterMLABackend(MLACommonBackend):
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto", "auto",
"float16",
"bfloat16", "bfloat16",
"fp8", "fp8",
"fp8_e4m3", "fp8_e4m3",
......
...@@ -82,6 +82,7 @@ class ROCMAiterMLASparseBackend(AttentionBackend): ...@@ -82,6 +82,7 @@ class ROCMAiterMLASparseBackend(AttentionBackend):
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto", "auto",
"float16",
"bfloat16", "bfloat16",
] ]
......
...@@ -31,6 +31,7 @@ class TritonMLABackend(MLACommonBackend): ...@@ -31,6 +31,7 @@ class TritonMLABackend(MLACommonBackend):
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto", "auto",
"float16",
"bfloat16", "bfloat16",
"fp8", "fp8",
"fp8_e4m3", "fp8_e4m3",
......
...@@ -38,6 +38,7 @@ class XPUMLASparseBackend(AttentionBackend): ...@@ -38,6 +38,7 @@ class XPUMLASparseBackend(AttentionBackend):
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto", "auto",
"float16",
"bfloat16", "bfloat16",
] ]
......
...@@ -736,6 +736,7 @@ class AiterFlashAttentionBackend(AttentionBackend): ...@@ -736,6 +736,7 @@ class AiterFlashAttentionBackend(AttentionBackend):
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto", "auto",
"float16",
"bfloat16", "bfloat16",
"fp8", "fp8",
"fp8_e4m3", "fp8_e4m3",
......
...@@ -166,6 +166,7 @@ class RocmAttentionBackend(AttentionBackend): ...@@ -166,6 +166,7 @@ class RocmAttentionBackend(AttentionBackend):
] ]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto", "auto",
"float16",
"bfloat16", "bfloat16",
"fp8", "fp8",
"fp8_e4m3", "fp8_e4m3",
......
...@@ -10,6 +10,7 @@ import torch ...@@ -10,6 +10,7 @@ import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
...@@ -31,6 +32,11 @@ logger = init_logger(__name__) ...@@ -31,6 +32,11 @@ logger = init_logger(__name__)
class TreeAttentionBackend(AttentionBackend): class TreeAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
"float16",
"bfloat16",
]
forward_includes_kv_cache_update: bool = False forward_includes_kv_cache_update: bool = False
@staticmethod @staticmethod
......
...@@ -263,6 +263,7 @@ class TritonAttentionBackend(AttentionBackend): ...@@ -263,6 +263,7 @@ class TritonAttentionBackend(AttentionBackend):
] ]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto", "auto",
"float16",
"bfloat16", "bfloat16",
"fp8", "fp8",
"fp8_e4m3", "fp8_e4m3",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment