"vscode:/vscode.git/clone" did not exist on "1b2c440cd646dac290b535b86be89d22fbdbeab9"
Unverified Commit 1174723e authored by Michael Goin's avatar Michael Goin Committed by GitHub
Browse files

Fix TURBOQUANT backend selection in cuda.py (#40060)


Signed-off-by: default avatarMichael Goin <mgoin64@gmail.com>
parent 6b2b7bd0
...@@ -106,6 +106,7 @@ Priority is **1 = highest** (tried first). ...@@ -106,6 +106,7 @@ Priority is **1 = highest** (tried first).
| 2 | `FLASH_ATTN` | | 2 | `FLASH_ATTN` |
| 3 | `TRITON_ATTN` | | 3 | `TRITON_ATTN` |
| 4 | `FLEX_ATTENTION` | | 4 | `FLEX_ATTENTION` |
| 5 | `TURBOQUANT` |
**Ampere/Hopper (SM 8.x-9.x):** **Ampere/Hopper (SM 8.x-9.x):**
...@@ -115,6 +116,7 @@ Priority is **1 = highest** (tried first). ...@@ -115,6 +116,7 @@ Priority is **1 = highest** (tried first).
| 2 | `FLASHINFER` | | 2 | `FLASHINFER` |
| 3 | `TRITON_ATTN` | | 3 | `TRITON_ATTN` |
| 4 | `FLEX_ATTENTION` | | 4 | `FLEX_ATTENTION` |
| 5 | `TURBOQUANT` |
### MLA Attention (DeepSeek-style) ### MLA Attention (DeepSeek-style)
......
...@@ -131,6 +131,7 @@ def _get_backend_priorities( ...@@ -131,6 +131,7 @@ def _get_backend_priorities(
AttentionBackendEnum.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TRITON_ATTN, AttentionBackendEnum.TRITON_ATTN,
AttentionBackendEnum.FLEX_ATTENTION, AttentionBackendEnum.FLEX_ATTENTION,
AttentionBackendEnum.TURBOQUANT,
] ]
else: else:
return [ return [
...@@ -138,6 +139,7 @@ def _get_backend_priorities( ...@@ -138,6 +139,7 @@ def _get_backend_priorities(
AttentionBackendEnum.FLASHINFER, AttentionBackendEnum.FLASHINFER,
AttentionBackendEnum.TRITON_ATTN, AttentionBackendEnum.TRITON_ATTN,
AttentionBackendEnum.FLEX_ATTENTION, AttentionBackendEnum.FLEX_ATTENTION,
AttentionBackendEnum.TURBOQUANT,
] ]
...@@ -255,11 +257,6 @@ class CudaPlatformBase(Platform): ...@@ -255,11 +257,6 @@ class CudaPlatformBase(Platform):
valid_backends_priorities = [] valid_backends_priorities = []
invalid_reasons: dict[AttentionBackendEnum, tuple[int, list[str]]] = {} invalid_reasons: dict[AttentionBackendEnum, tuple[int, list[str]]] = {}
# TurboQuant KV cache: route directly to TQ backend
kv_cache_dtype = attn_selector_config.kv_cache_dtype
if kv_cache_dtype is not None and kv_cache_dtype.startswith("turboquant_"):
return [(AttentionBackendEnum.TURBOQUANT, 0)], {}
backend_priorities = _get_backend_priorities( backend_priorities = _get_backend_priorities(
attn_selector_config.use_mla, attn_selector_config.use_mla,
device_capability, device_capability,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment