Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
77431529
Unverified
Commit
77431529
authored
Feb 17, 2026
by
Matthew Bonanni
Committed by
GitHub
Feb 17, 2026
Browse files
[Attention] Refactor `check_and_update_config` (#33600)
Signed-off-by:
Matthew Bonanni
<
mbonanni@redhat.com
>
parent
ab33d2a6
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
270 additions
and
172 deletions
+270
-172
vllm/config/cache.py
vllm/config/cache.py
+4
-7
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+1
-2
vllm/platforms/cuda.py
vllm/platforms/cuda.py
+253
-156
vllm/v1/attention/backend.py
vllm/v1/attention/backend.py
+12
-7
No files found.
vllm/config/cache.py
View file @
77431529
...
@@ -19,7 +19,6 @@ else:
...
@@ -19,7 +19,6 @@ else:
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
BlockSize
=
Literal
[
1
,
8
,
16
,
32
,
64
,
128
,
256
]
CacheDType
=
Literal
[
CacheDType
=
Literal
[
"auto"
,
"auto"
,
"bfloat16"
,
"bfloat16"
,
...
@@ -39,13 +38,11 @@ KVOffloadingBackend = Literal["native", "lmcache"]
...
@@ -39,13 +38,11 @@ KVOffloadingBackend = Literal["native", "lmcache"]
class
CacheConfig
:
class
CacheConfig
:
"""Configuration for the KV cache."""
"""Configuration for the KV cache."""
block_size
:
SkipValidation
[
BlockSize
]
=
None
# type: ignore[assignment]
block_size
:
SkipValidation
[
int
]
=
None
# type: ignore[assignment]
"""Size of a contiguous cache block in number of tokens. On CUDA devices,
"""Size of a contiguous cache block in number of tokens.
only block sizes up to 32 are supported.
This config has no static default. If left unspecified by the user, it will
This is None until `Platform.check_and_update_config()` sets it based on
be set in `Platform.check_and_update_config()` based on the current
the current platform. Always an int by the time the engine starts."""
platform."""
gpu_memory_utilization
:
float
=
Field
(
default
=
0.9
,
gt
=
0
,
le
=
1
)
gpu_memory_utilization
:
float
=
Field
(
default
=
0.9
,
gt
=
0
,
le
=
1
)
"""The fraction of GPU memory to be used for the model executor, which can
"""The fraction of GPU memory to be used for the model executor, which can
range from 0 to 1. For example, a value of 0.5 would imply 50% GPU memory
range from 0 to 1. For example, a value of 0.5 would imply 50% GPU memory
...
...
vllm/engine/arg_utils.py
View file @
77431529
...
@@ -59,7 +59,6 @@ from vllm.config import (
...
@@ -59,7 +59,6 @@ from vllm.config import (
get_attr_docs
,
get_attr_docs
,
)
)
from
vllm.config.cache
import
(
from
vllm.config.cache
import
(
BlockSize
,
CacheDType
,
CacheDType
,
KVOffloadingBackend
,
KVOffloadingBackend
,
MambaCacheMode
,
MambaCacheMode
,
...
@@ -431,7 +430,7 @@ class EngineArgs:
...
@@ -431,7 +430,7 @@ class EngineArgs:
max_parallel_loading_workers
:
int
|
None
=
(
max_parallel_loading_workers
:
int
|
None
=
(
ParallelConfig
.
max_parallel_loading_workers
ParallelConfig
.
max_parallel_loading_workers
)
)
block_size
:
BlockSize
=
CacheConfig
.
block_size
block_size
:
int
=
None
# type: ignore[assignment]
enable_prefix_caching
:
bool
|
None
=
None
enable_prefix_caching
:
bool
|
None
=
None
prefix_caching_hash_algo
:
PrefixCachingHashAlgo
=
(
prefix_caching_hash_algo
:
PrefixCachingHashAlgo
=
(
CacheConfig
.
prefix_caching_hash_algo
CacheConfig
.
prefix_caching_hash_algo
...
...
vllm/platforms/cuda.py
View file @
77431529
...
@@ -163,8 +163,6 @@ class CudaPlatformBase(Platform):
...
@@ -163,8 +163,6 @@ class CudaPlatformBase(Platform):
@
classmethod
@
classmethod
def
check_and_update_config
(
cls
,
vllm_config
:
"VllmConfig"
)
->
None
:
def
check_and_update_config
(
cls
,
vllm_config
:
"VllmConfig"
)
->
None
:
from
vllm.v1.attention.backends.registry
import
AttentionBackendEnum
parallel_config
=
vllm_config
.
parallel_config
parallel_config
=
vllm_config
.
parallel_config
model_config
=
vllm_config
.
model_config
model_config
=
vllm_config
.
model_config
...
@@ -172,112 +170,19 @@ class CudaPlatformBase(Platform):
...
@@ -172,112 +170,19 @@ class CudaPlatformBase(Platform):
parallel_config
.
worker_cls
=
"vllm.v1.worker.gpu_worker.Worker"
parallel_config
.
worker_cls
=
"vllm.v1.worker.gpu_worker.Worker"
cache_config
=
vllm_config
.
cache_config
cache_config
=
vllm_config
.
cache_config
if
cache_config
and
cache_config
.
block_size
is
None
:
user_specified_block_size
=
cache_config
.
block_size
is
not
None
if
not
user_specified_block_size
:
cache_config
.
block_size
=
16
cache_config
.
block_size
=
16
# TODO(lucas): handle this more gracefully
# Ensure block_size is compatible with the attention backend.
# Note: model_config may be None during testing
# Note: model_config may be None during testing.
# Note: block_size is initialized in
# Skip hybrid (attention+mamba) models — their block_size is
# HybridAttentionMambaModelConfig.verify_and_update_config
# managed by HybridAttentionMambaModelConfig
# for models with both attention and mamba,
if
model_config
is
not
None
and
not
model_config
.
is_hybrid
:
# and doesn't need to be reinitialized here
cls
.
_update_block_size_for_backend
(
if
(
vllm_config
,
model_config
is
not
None
user_specified_block_size
,
and
model_config
.
use_mla
)
and
cache_config
.
block_size
is
not
None
):
use_sparse
=
hasattr
(
vllm_config
.
model_config
.
hf_config
,
"index_topk"
)
# If `--attention-config.backend` is not set and we are using MLA,
# then we default to FlashMLA backend for non-blackwell GPUs,
# else we default to CutlassMLA. For each case, we force the
# required block_size.
use_flashmla
=
False
use_cutlass_mla
=
False
use_flashinfer_mla
=
False
use_flashmla_sparse
=
False
use_flashinfer_mla_sparse
=
False
from
vllm.v1.attention.ops.flashmla
import
is_flashmla_dense_supported
if
vllm_config
.
attention_config
.
backend
is
None
:
# Default case
hf_text_config
=
model_config
.
hf_text_config
qk_nope_head_dim
=
getattr
(
hf_text_config
,
"qk_nope_head_dim"
,
1
)
if
(
cls
.
is_device_capability_family
(
100
)
and
not
use_sparse
and
qk_nope_head_dim
==
128
):
# Blackwell => Force FlashInfer MLA (unless sparse, i.e. DSv3.2)
# and only if qk_nope_head_dim == 128 (kernel constraint)
use_flashinfer_mla
=
True
# Set the backend in AttentionConfig so it's used during
# backend selection
vllm_config
.
attention_config
.
backend
=
(
AttentionBackendEnum
.
FLASHINFER_MLA
)
elif
cls
.
is_device_capability_family
(
100
)
and
not
use_sparse
:
# Fall back to CUTLASS_MLA as 2nd priority on Blackwell
use_cutlass_mla
=
True
elif
is_flashmla_dense_supported
()[
0
]:
# Non-Blackwell with FlashMLA support
use_flashmla
=
True
else
:
# Fallback: will use Triton MLA or other compatible backend
pass
else
:
# Forced case
backend
=
vllm_config
.
attention_config
.
backend
use_flashmla
=
backend
==
AttentionBackendEnum
.
FLASHMLA
use_cutlass_mla
=
backend
==
AttentionBackendEnum
.
CUTLASS_MLA
use_flashinfer_mla
=
backend
==
AttentionBackendEnum
.
FLASHINFER_MLA
use_flashmla_sparse
=
backend
==
AttentionBackendEnum
.
FLASHMLA_SPARSE
use_flashinfer_mla_sparse
=
(
backend
==
AttentionBackendEnum
.
FLASHINFER_MLA_SPARSE
)
if
(
use_flashmla
and
is_flashmla_dense_supported
()[
0
]
and
cache_config
.
block_size
%
64
!=
0
):
cache_config
.
block_size
=
64
logger
.
info
(
"Forcing kv cache block size to 64 for FlashMLA backend."
)
if
use_cutlass_mla
and
cache_config
.
block_size
%
128
!=
0
:
cache_config
.
block_size
=
128
logger
.
info
(
"Forcing kv cache block size to 128 for CUTLASS_MLA backend."
)
if
(
use_flashinfer_mla
and
cache_config
.
block_size
!=
32
and
cache_config
.
block_size
%
64
!=
0
):
cache_config
.
block_size
=
64
logger
.
info
(
"Forcing kv cache block size to 64 for FlashInferMLA backend."
)
if
use_sparse
:
if
not
(
use_flashmla_sparse
or
use_flashinfer_mla_sparse
):
use_flashmla_sparse
=
True
if
use_flashmla_sparse
and
cache_config
.
block_size
!=
64
:
cache_config
.
block_size
=
64
logger
.
info
(
"Forcing kv cache block size to 64 for FlashMLASparse backend."
)
elif
use_flashinfer_mla_sparse
and
cache_config
.
block_size
not
in
(
32
,
64
,
):
cache_config
.
block_size
=
64
logger
.
info
(
"Forcing kv cache block size to 64 for FlashInferMLASparse "
"backend."
)
scheduler_config
=
vllm_config
.
scheduler_config
scheduler_config
=
vllm_config
.
scheduler_config
# Note: model_config may be None during testing
# Note: model_config may be None during testing
...
@@ -293,6 +198,150 @@ class CudaPlatformBase(Platform):
...
@@ -293,6 +198,150 @@ class CudaPlatformBase(Platform):
)
)
scheduler_config
.
disable_chunked_mm_input
=
True
scheduler_config
.
disable_chunked_mm_input
=
True
@
classmethod
def
_update_block_size_for_backend
(
cls
,
vllm_config
:
"VllmConfig"
,
user_specified_block_size
:
bool
,
)
->
None
:
"""Ensure block_size is compatible with the attention backend.
If the user specified --block-size, the selector validates/filters
backends by that block size (raising on incompatibility). Otherwise,
the backend is selected unconstrained and block_size is set to the
backend's preferred value.
"""
from
vllm.config.vllm
import
set_current_vllm_config
from
vllm.v1.attention.selector
import
AttentionSelectorConfig
model_config
=
vllm_config
.
model_config
cache_config
=
vllm_config
.
cache_config
device_capability
=
cls
.
get_device_capability
()
if
device_capability
is
None
:
return
use_mla
=
model_config
.
use_mla
attn_selector_config
=
AttentionSelectorConfig
(
head_size
=
model_config
.
get_head_size
(),
dtype
=
model_config
.
dtype
,
# type: ignore[arg-type]
kv_cache_dtype
=
cache_config
.
cache_dtype
,
block_size
=
cache_config
.
block_size
if
user_specified_block_size
else
None
,
use_mla
=
use_mla
,
has_sink
=
False
,
use_sparse
=
use_mla
and
hasattr
(
model_config
.
hf_config
,
"index_topk"
),
use_mm_prefix
=
model_config
.
is_mm_prefix_lm
,
)
user_specified_backend
=
vllm_config
.
attention_config
.
backend
num_heads
=
model_config
.
get_num_attention_heads
(
vllm_config
.
parallel_config
,
)
with
set_current_vllm_config
(
vllm_config
):
chosen_backend
=
cls
.
select_attention_backend
(
selected_backend
=
user_specified_backend
,
attn_selector_config
=
attn_selector_config
,
device_capability
=
device_capability
,
# Don't raise here — we produce better errors below.
raise_on_invalid
=
False
,
num_heads
=
num_heads
,
)
# If the user's --block-size forced a non-optimal backend,
# warn them. Only relevant when the user didn't also specify
# --attention-backend (in which case the choice is explicit).
if
(
chosen_backend
is
not
None
and
user_specified_block_size
and
user_specified_backend
is
None
):
optimal
=
cls
.
select_attention_backend
(
selected_backend
=
None
,
attn_selector_config
=
attn_selector_config
.
_replace
(
block_size
=
None
,
),
device_capability
=
device_capability
,
raise_on_invalid
=
False
,
num_heads
=
num_heads
,
)
if
optimal
is
not
None
and
optimal
!=
chosen_backend
:
logger
.
warning
(
"--block-size %d is not supported by the preferred "
"%s backend. Using %s instead, which may result "
"in reduced performance. Consider removing "
"--block-size to auto-select the optimal "
"block size."
,
cache_config
.
block_size
,
optimal
.
name
,
chosen_backend
.
name
,
)
if
chosen_backend
is
not
None
:
if
user_specified_block_size
:
# User's block_size is compatible with the chosen
# backend.
return
# User didn't specify --block-size, so auto-select the
# preferred block size for the chosen backend.
try
:
backend_class
=
chosen_backend
.
get_class
()
except
ImportError
:
return
# Will fail later with a better error
preferred
=
backend_class
.
get_preferred_block_size
(
cache_config
.
block_size
,
)
if
cache_config
.
block_size
!=
preferred
:
logger
.
info
(
"Setting kv cache block size to %d for %s backend."
,
preferred
,
chosen_backend
.
name
,
)
cache_config
.
block_size
=
preferred
return
# No valid backend found. If the user didn't constrain the
# selection, defer the error to get_attn_backend_cls where
# the full config (including per-layer settings) is
# available.
if
not
user_specified_block_size
:
return
if
user_specified_backend
is
not
None
:
# User specified --block-size and --attention-backend
# and they are incompatible.
try
:
backend_class
=
user_specified_backend
.
get_class
()
supported
=
backend_class
.
get_supported_kernel_block_sizes
()
except
ImportError
:
supported
=
None
raise
ValueError
(
f
"User-specified --block-size "
f
"
{
cache_config
.
block_size
}
is incompatible with "
f
"the specified --attention-backend "
f
"
{
user_specified_backend
.
name
}
(supported kernel "
f
"block sizes:
{
supported
}
). Either remove "
f
"--block-size to auto-select, or choose a "
f
"compatible value."
)
else
:
# User specified --block-size but no backend supports
# it.
_
,
invalid_reasons
=
cls
.
get_valid_backends
(
device_capability
=
device_capability
,
attn_selector_config
=
attn_selector_config
,
num_heads
=
num_heads
,
)
reasons_str
=
", "
.
join
(
f
"
{
b
.
name
}
: [
{
', '
.
join
(
r
)
}
]"
for
b
,
r
in
invalid_reasons
.
items
()
)
raise
ValueError
(
f
"No valid attention backend found for "
f
"--block-size
{
cache_config
.
block_size
}
. "
f
"Reasons: {{
{
reasons_str
}
}}. Either remove "
f
"--block-size to auto-select, or choose a "
f
"compatible value."
)
@
classmethod
@
classmethod
def
get_current_memory_usage
(
def
get_current_memory_usage
(
cls
,
device
:
torch
.
types
.
Device
|
None
=
None
cls
,
device
:
torch
.
types
.
Device
|
None
=
None
...
@@ -336,77 +385,125 @@ class CudaPlatformBase(Platform):
...
@@ -336,77 +385,125 @@ class CudaPlatformBase(Platform):
return
valid_backends_priorities
,
invalid_reasons
return
valid_backends_priorities
,
invalid_reasons
@
classmethod
@
classmethod
def
get_att
n_backend
_cls
(
def
select_attentio
n_backend
(
cls
,
cls
,
selected_backend
:
"AttentionBackendEnum"
,
selected_backend
:
"AttentionBackendEnum
| None
"
,
attn_selector_config
:
"AttentionSelectorConfig"
,
attn_selector_config
:
"AttentionSelectorConfig"
,
device_capability
:
"DeviceCapability"
,
raise_on_invalid
:
bool
=
True
,
num_heads
:
int
|
None
=
None
,
num_heads
:
int
|
None
=
None
,
)
->
str
:
)
->
"AttentionBackendEnum | None"
:
device_capability
=
cls
.
get_device_capability
()
"""Select the best attention backend for the given configuration.
assert
device_capability
is
not
None
Args:
attn_selector_config
=
attn_selector_config
.
_replace
(
block_size
=
None
)
selected_backend: User-specified backend, or None for auto-selection
attn_selector_config: Configuration for attention selection
device_capability: Device capability info
raise_on_invalid: If True, raise ValueError when no valid backend
num_heads: Number of attention heads per GPU, used for backend
priority ordering on Blackwell GPUs
Returns:
The selected backend enum, or None if no valid backend found
and raise_on_invalid is False
"""
# First try checking just the selected backend, if there is one.
# First try checking just the selected backend, if there is one.
if
selected_backend
is
not
None
:
if
selected_backend
is
not
None
:
try
:
try
:
backend_class
=
selected_backend
.
get_class
()
backend_class
=
selected_backend
.
get_class
()
in
valid
_reason
s
=
backend_class
.
validate_configuration
(
valid
ation_error
s
=
backend_class
.
validate_configuration
(
device_capability
=
device_capability
,
device_capability
=
device_capability
,
**
attn_selector_config
.
_asdict
(),
**
attn_selector_config
.
_asdict
(),
)
)
except
ImportError
:
except
ImportError
:
in
valid
_reason
s
=
[
"ImportError"
]
valid
ation_error
s
=
[
"ImportError"
]
if
in
valid
_reason
s
:
if
valid
ation_error
s
:
raise
ValueError
(
if
raise
_on_invalid
:
f
"Selected backend
{
selected_backend
}
is not valid for "
raise
ValueError
(
f
"this configuration. Reason:
{
invalid_reasons
}
"
f
"Selected backend
{
selected_backend
}
is not valid for
"
)
f
"this configuration. Reason:
{
validation_errors
}
"
else
:
)
logger
.
info
(
"Using %s backend."
,
selected_backend
)
return
None
return
selected_backend
.
get_path
()
return
selected_backend
# No selected backend or the selected backend is invalid,
# No selected backend, so find the best valid one.
# so we try finding a valid backend.
valid_backends_priorities
,
invalid_reasons
=
cls
.
get_valid_backends
(
valid_backends_priorities
,
invalid_reasons
=
cls
.
get_valid_backends
(
device_capability
=
device_capability
,
device_capability
=
device_capability
,
attn_selector_config
=
attn_selector_config
,
attn_selector_config
=
attn_selector_config
,
num_heads
=
num_heads
,
num_heads
=
num_heads
,
)
)
reasons_str
=
(
"{"
+
", "
.
join
(
f
"
{
backend
.
name
}
: [
{
', '
.
join
(
reasons
)
}
]"
for
backend
,
reasons
in
invalid_reasons
.
items
()
)
+
"}"
)
config_str
=
attn_selector_config
.
__repr__
()
logger
.
debug_once
(
f
"Some attention backends are not valid for
{
cls
.
device_name
}
with "
f
"
{
config_str
}
. Reasons:
{
reasons_str
}
."
)
if
len
(
valid_backends_priorities
)
==
0
:
if
len
(
valid_backends_priorities
)
==
0
:
raise
ValueError
(
if
raise_on_invalid
:
f
"No valid attention backend found for
{
cls
.
device_name
}
"
reasons_str
=
(
f
"with
{
config_str
}
. Reasons:
{
reasons_str
}
."
"{"
)
+
", "
.
join
(
f
"
{
backend
.
name
}
: [
{
', '
.
join
(
reasons
)
}
]"
for
backend
,
reasons
in
invalid_reasons
.
items
()
)
+
"}"
)
config_str
=
attn_selector_config
.
__repr__
()
raise
ValueError
(
f
"No valid attention backend found for
{
cls
.
device_name
}
"
f
"with
{
config_str
}
. Reasons:
{
reasons_str
}
."
)
return
None
# We have found some valid backends. Select the one with the
# Select the one with the highest priority (lowest index).
# highest priority.
sorted_backends
=
sorted
(
valid_backends_priorities
,
key
=
lambda
x
:
x
[
1
])
sorted_indices
=
sorted
(
return
sorted_backends
[
0
][
0
]
range
(
len
(
valid_backends_priorities
)),
key
=
lambda
i
:
valid_backends_priorities
[
i
][
1
],
@
classmethod
)
def
get_attn_backend_cls
(
selected_index
=
sorted_indices
[
0
]
cls
,
selected_backend
=
valid_backends_priorities
[
selected_index
][
0
]
selected_backend
:
"AttentionBackendEnum | None"
,
logger
.
info_once
(
attn_selector_config
:
"AttentionSelectorConfig"
,
"Using %s attention backend out of potential backends: %s."
,
num_heads
:
int
|
None
=
None
,
selected_backend
.
name
,
)
->
str
:
"["
+
", "
.
join
(
f
"'
{
b
[
0
].
name
}
'"
for
b
in
valid_backends_priorities
)
+
"]"
,
device_capability
=
cls
.
get_device_capability
()
scope
=
"local"
,
assert
device_capability
is
not
None
chosen_backend
=
cls
.
select_attention_backend
(
selected_backend
=
selected_backend
,
attn_selector_config
=
attn_selector_config
,
num_heads
=
num_heads
,
device_capability
=
device_capability
,
raise_on_invalid
=
True
,
)
)
assert
chosen_backend
is
not
None
# raise_on_invalid=True guarantees this
# Log the selection
if
selected_backend
is
not
None
:
logger
.
info
(
"Using %s backend."
,
chosen_backend
)
else
:
# Get all valid backends for logging
valid_backends_priorities
,
invalid_reasons
=
cls
.
get_valid_backends
(
device_capability
=
device_capability
,
attn_selector_config
=
attn_selector_config
,
num_heads
=
num_heads
,
)
reasons_str
=
(
"{"
+
", "
.
join
(
f
"
{
backend
.
name
}
: [
{
', '
.
join
(
reasons
)
}
]"
for
backend
,
reasons
in
invalid_reasons
.
items
()
)
+
"}"
)
config_str
=
attn_selector_config
.
__repr__
()
logger
.
debug_once
(
f
"Some attention backends are not valid for
{
cls
.
device_name
}
with "
f
"
{
config_str
}
. Reasons:
{
reasons_str
}
."
)
logger
.
info_once
(
"Using %s attention backend out of potential backends: %s"
,
chosen_backend
.
name
,
tuple
(
b
[
0
].
name
for
b
in
valid_backends_priorities
),
scope
=
"local"
,
)
return
selected
_backend
.
get_path
()
return
chosen
_backend
.
get_path
()
@
classmethod
@
classmethod
def
get_supported_vit_attn_backends
(
cls
)
->
list
[
"AttentionBackendEnum"
]:
def
get_supported_vit_attn_backends
(
cls
)
->
list
[
"AttentionBackendEnum"
]:
...
...
vllm/v1/attention/backend.py
View file @
77431529
...
@@ -4,7 +4,7 @@
...
@@ -4,7 +4,7 @@
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
dataclasses
import
dataclass
,
replace
from
dataclasses
import
dataclass
,
replace
from
enum
import
Enum
from
enum
import
Enum
from
typing
import
TYPE_CHECKING
,
Any
,
ClassVar
,
Generic
,
Protocol
,
TypeVar
,
get_args
from
typing
import
TYPE_CHECKING
,
Any
,
ClassVar
,
Generic
,
Protocol
,
TypeVar
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -144,15 +144,9 @@ class AttentionBackend(ABC):
...
@@ -144,15 +144,9 @@ class AttentionBackend(ABC):
@
classmethod
@
classmethod
def
supports_block_size
(
cls
,
block_size
:
int
|
None
)
->
bool
:
def
supports_block_size
(
cls
,
block_size
:
int
|
None
)
->
bool
:
from
vllm.config.cache
import
BlockSize
if
block_size
is
None
:
if
block_size
is
None
:
return
True
return
True
valid_sizes
=
get_args
(
BlockSize
)
if
block_size
not
in
valid_sizes
:
return
False
supported_kernel_block_sizes
=
cls
.
get_supported_kernel_block_sizes
()
supported_kernel_block_sizes
=
cls
.
get_supported_kernel_block_sizes
()
if
not
supported_kernel_block_sizes
:
if
not
supported_kernel_block_sizes
:
return
True
return
True
...
@@ -167,6 +161,17 @@ class AttentionBackend(ABC):
...
@@ -167,6 +161,17 @@ class AttentionBackend(ABC):
return
True
return
True
return
False
return
False
@
classmethod
def
get_preferred_block_size
(
cls
,
default_block_size
:
int
=
16
)
->
int
:
supported_sizes
=
cls
.
get_supported_kernel_block_sizes
()
if
not
supported_sizes
:
return
default_block_size
if
cls
.
supports_block_size
(
default_block_size
):
return
default_block_size
return
min
(
s
.
base
if
isinstance
(
s
,
MultipleOf
)
else
s
for
s
in
supported_sizes
)
@
classmethod
@
classmethod
def
is_mla
(
cls
)
->
bool
:
def
is_mla
(
cls
)
->
bool
:
return
False
return
False
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment