Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
fc1d8be3
Unverified
Commit
fc1d8be3
authored
Nov 27, 2025
by
Matthew Bonanni
Committed by
GitHub
Nov 27, 2025
Browse files
[Attention] Update attention imports (#29540)
Signed-off-by:
Matthew Bonanni
<
mbonanni@redhat.com
>
parent
cd007a53
Changes
38
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
41 additions
and
68 deletions
+41
-68
tests/v1/attention/test_rocm_attention_backends_selection.py
tests/v1/attention/test_rocm_attention_backends_selection.py
+3
-6
tests/v1/kv_connector/unit/test_backwards_compatibility.py
tests/v1/kv_connector/unit/test_backwards_compatibility.py
+3
-3
vllm/attention/backends/abstract.py
vllm/attention/backends/abstract.py
+4
-7
vllm/attention/layers/chunked_local_attention.py
vllm/attention/layers/chunked_local_attention.py
+1
-2
vllm/config/model.py
vllm/config/model.py
+1
-2
vllm/config/multimodal.py
vllm/config/multimodal.py
+2
-9
vllm/distributed/kv_transfer/kv_connector/v1/base.py
vllm/distributed/kv_transfer/kv_connector/v1/base.py
+2
-2
vllm/distributed/kv_transfer/kv_connector/v1/decode_bench_connector.py
...ted/kv_transfer/kv_connector/v1/decode_bench_connector.py
+2
-2
vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py
...tributed/kv_transfer/kv_connector/v1/lmcache_connector.py
+2
-2
vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py
...er/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py
+2
-2
vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py
...buted/kv_transfer/kv_connector/v1/lmcache_mp_connector.py
+2
-2
vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
...istributed/kv_transfer/kv_connector/v1/multi_connector.py
+2
-2
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
...distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+2
-3
vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
...ted/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
+2
-2
vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
...d/kv_transfer/kv_connector/v1/shared_storage_connector.py
+2
-2
vllm/forward_context.py
vllm/forward_context.py
+3
-5
vllm/model_executor/layers/attention_layer_base.py
vllm/model_executor/layers/attention_layer_base.py
+2
-5
vllm/model_executor/layers/mamba/abstract.py
vllm/model_executor/layers/mamba/abstract.py
+2
-5
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
...ers/quantization/compressed_tensors/compressed_tensors.py
+1
-2
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+1
-3
No files found.
tests/v1/attention/test_rocm_attention_backends_selection.py
View file @
fc1d8be3
...
...
@@ -139,14 +139,13 @@ def test_standard_attention_backend_selection(
import
importlib
import
vllm.envs
as
envs
from
vllm.attention.backends.registry
import
_Backend
importlib
.
reload
(
envs
)
# Convert string backend to enum if provided
backend_enum
=
None
if
selected_backend
:
backend_enum
=
getattr
(
_
Backend
,
selected_backend
)
backend_enum
=
getattr
(
Attention
Backend
Enum
,
selected_backend
)
# Get the backend class path
from
vllm.platforms.rocm
import
RocmPlatform
...
...
@@ -253,7 +252,6 @@ def test_mla_backend_selection(
import
importlib
import
vllm.envs
as
envs
from
vllm.attention.backends.registry
import
_Backend
importlib
.
reload
(
envs
)
...
...
@@ -269,7 +267,7 @@ def test_mla_backend_selection(
# Convert string backend to enum if provided
backend_enum
=
None
if
selected_backend
:
backend_enum
=
getattr
(
_
Backend
,
selected_backend
)
backend_enum
=
getattr
(
Attention
Backend
Enum
,
selected_backend
)
from
vllm.platforms.rocm
import
RocmPlatform
...
...
@@ -301,7 +299,6 @@ def test_mla_backend_selection(
def
test_aiter_fa_requires_gfx9
(
mock_vllm_config
):
"""Test that ROCM_AITER_FA requires gfx9 architecture."""
from
vllm.attention.backends.registry
import
_Backend
from
vllm.platforms.rocm
import
RocmPlatform
# Mock on_gfx9 to return False
...
...
@@ -313,7 +310,7 @@ def test_aiter_fa_requires_gfx9(mock_vllm_config):
),
):
RocmPlatform
.
get_attn_backend_cls
(
selected_backend
=
_
Backend
.
ROCM_AITER_FA
,
selected_backend
=
Attention
Backend
Enum
.
ROCM_AITER_FA
,
head_size
=
128
,
dtype
=
torch
.
float16
,
kv_cache_dtype
=
"auto"
,
...
...
tests/v1/kv_connector/unit/test_backwards_compatibility.py
View file @
fc1d8be3
...
...
@@ -14,6 +14,7 @@ from unittest.mock import patch
import
pytest
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.distributed.kv_transfer.kv_connector.factory
import
KVConnectorFactory
from
vllm.distributed.kv_transfer.kv_connector.v1
import
(
KVConnectorBase_V1
,
...
...
@@ -24,7 +25,6 @@ from vllm.v1.core.sched.output import SchedulerOutput
from
.utils
import
create_scheduler
,
create_vllm_config
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.forward_context
import
ForwardContext
from
vllm.v1.core.kv_cache_manager
import
KVCacheBlocks
...
...
@@ -68,7 +68,7 @@ class OldStyleTestConnector(KVConnectorBase_V1):
self
,
layer_name
:
str
,
kv_layer
,
attn_metadata
:
"
AttentionMetadata
"
,
attn_metadata
:
AttentionMetadata
,
**
kwargs
,
)
->
None
:
pass
...
...
@@ -119,7 +119,7 @@ class NewStyleTestConnector(KVConnectorBase_V1):
self
,
layer_name
:
str
,
kv_layer
,
attn_metadata
:
"
AttentionMetadata
"
,
attn_metadata
:
AttentionMetadata
,
**
kwargs
,
)
->
None
:
pass
...
...
vllm/attention/backends/abstract.py
View file @
fc1d8be3
...
...
@@ -6,11 +6,10 @@ from typing import TYPE_CHECKING, ClassVar, Generic, Protocol, TypeVar, get_args
import
torch
from
vllm.model_executor.layers.linear
import
ColumnParallelLinear
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
QuantKey
if
TYPE_CHECKING
:
from
vllm.config.cache
import
CacheDType
from
vllm.model_executor.layers.linear
import
ColumnParallelLinear
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
QuantKey
from
vllm.platforms.interface
import
DeviceCapability
from
vllm.v1.attention.backends.utils
import
KVCacheLayoutType
...
...
@@ -178,8 +177,6 @@ class AttentionBackend(ABC):
By default, only supports decoder attention.
Backends should override this to support other attention types.
"""
from
vllm.attention.backends.abstract
import
AttentionType
return
attn_type
==
AttentionType
.
DECODER
@
classmethod
...
...
@@ -360,7 +357,7 @@ class AttentionImpl(ABC, Generic[T]):
)
->
torch
.
Tensor
:
raise
NotImplementedError
def
fused_output_quant_supported
(
self
,
quant_key
:
QuantKey
):
def
fused_output_quant_supported
(
self
,
quant_key
:
"
QuantKey
"
):
"""
Does this attention implementation support fused output quantization.
This is used by the AttnFusionPass to only fuse output quantization
...
...
@@ -412,7 +409,7 @@ class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
qk_rope_head_dim
:
int
,
qk_head_dim
:
int
,
v_head_dim
:
int
,
kv_b_proj
:
ColumnParallelLinear
,
kv_b_proj
:
"
ColumnParallelLinear
"
,
indexer
:
object
|
None
=
None
,
)
->
None
:
raise
NotImplementedError
...
...
vllm/attention/layers/chunked_local_attention.py
View file @
fc1d8be3
...
...
@@ -5,6 +5,7 @@ import functools
import
torch
from
vllm.attention.backends.abstract
import
AttentionBackend
,
AttentionMetadata
from
vllm.attention.layer
import
Attention
from
vllm.attention.selector
import
get_attn_backend
from
vllm.config
import
CacheConfig
from
vllm.config.vllm
import
VllmConfig
...
...
@@ -22,8 +23,6 @@ from vllm.v1.kv_cache_interface import (
KVCacheSpec
,
)
from
..layer
import
Attention
@
functools
.
lru_cache
def
create_chunked_local_attention_backend
(
...
...
vllm/config/model.py
View file @
fc1d8be3
...
...
@@ -14,6 +14,7 @@ from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
from
transformers.configuration_utils
import
ALLOWED_LAYER_TYPES
import
vllm.envs
as
envs
from
vllm.attention.backends.registry
import
AttentionBackendEnum
from
vllm.config.multimodal
import
MMCacheType
,
MMEncoderTPMode
,
MultiModalConfig
from
vllm.config.pooler
import
PoolerConfig
from
vllm.config.scheduler
import
RunnerType
...
...
@@ -53,7 +54,6 @@ if TYPE_CHECKING:
import
vllm.model_executor.layers.quantization
as
me_quant
import
vllm.model_executor.models
as
me_models
from
vllm.attention.backends.registry
import
AttentionBackendEnum
from
vllm.config.load
import
LoadConfig
from
vllm.config.parallel
import
ParallelConfig
from
vllm.model_executor.layers.quantization
import
QuantizationMethods
...
...
@@ -61,7 +61,6 @@ if TYPE_CHECKING:
else
:
PretrainedConfig
=
Any
AttentionBackendEnum
=
Any
me_quant
=
LazyLoader
(
"model_executor"
,
globals
(),
"vllm.model_executor.layers.quantization"
)
...
...
vllm/config/multimodal.py
View file @
fc1d8be3
...
...
@@ -2,19 +2,15 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Mapping
from
typing
import
TYPE_CHECKING
,
Any
,
Literal
,
TypeAlias
from
typing
import
Any
,
Literal
,
TypeAlias
from
pydantic
import
ConfigDict
,
Field
,
field_validator
,
model_validator
from
pydantic.dataclasses
import
dataclass
from
vllm.attention.backends.registry
import
AttentionBackendEnum
from
vllm.config.utils
import
config
from
vllm.utils.hashing
import
safe_hash
if
TYPE_CHECKING
:
from
vllm.attention.backends.registry
import
AttentionBackendEnum
else
:
AttentionBackendEnum
=
Any
@
dataclass
class
BaseDummyOptions
:
...
...
@@ -170,9 +166,6 @@ class MultiModalConfig:
def
_validate_mm_encoder_attn_backend
(
cls
,
value
:
str
|
AttentionBackendEnum
|
None
)
->
AttentionBackendEnum
|
None
:
# We need to import the real type here (deferred to avoid circular import).
from
vllm.attention.backends.registry
import
AttentionBackendEnum
if
isinstance
(
value
,
str
)
and
value
.
upper
()
==
"XFORMERS"
:
raise
ValueError
(
"Attention backend 'XFORMERS' has been removed (See PR #29262 for "
...
...
vllm/distributed/kv_transfer/kv_connector/v1/base.py
View file @
fc1d8be3
...
...
@@ -42,12 +42,12 @@ from typing import TYPE_CHECKING, Any, ClassVar, Literal, Optional
import
torch
from
vllm.attention.backends.abstract
import
AttentionBackend
,
AttentionMetadata
from
vllm.logger
import
init_logger
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.outputs
import
KVConnectorOutput
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionBackend
,
AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.distributed.kv_events
import
KVCacheEvent
from
vllm.distributed.kv_transfer.kv_connector.v1.metrics
import
(
...
...
@@ -239,7 +239,7 @@ class KVConnectorBase_V1(ABC):
return
def
register_cross_layers_kv_cache
(
self
,
kv_cache
:
torch
.
Tensor
,
attn_backend
:
type
[
"
AttentionBackend
"
]
self
,
kv_cache
:
torch
.
Tensor
,
attn_backend
:
type
[
AttentionBackend
]
):
"""
Initialize with a single KV cache tensor used by all layers.
...
...
vllm/distributed/kv_transfer/kv_connector/v1/decode_bench_connector.py
View file @
fc1d8be3
...
...
@@ -36,6 +36,7 @@ from typing import TYPE_CHECKING, Any, Optional
import
torch
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.distributed.kv_transfer.kv_connector.v1
import
(
KVConnectorBase_V1
,
KVConnectorRole
,
...
...
@@ -45,7 +46,6 @@ from vllm.logger import init_logger
from
vllm.utils.math_utils
import
cdiv
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.forward_context
import
ForwardContext
from
vllm.v1.core.kv_cache_manager
import
KVCacheBlocks
...
...
@@ -117,7 +117,7 @@ class DecodeBenchConnector(KVConnectorBase_V1):
self
,
layer_name
:
str
,
kv_layer
:
torch
.
Tensor
,
attn_metadata
:
"
AttentionMetadata
"
,
attn_metadata
:
AttentionMetadata
,
**
kwargs
:
Any
,
)
->
None
:
# This connector doesn't save KV cache (benchmarking only)
...
...
vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py
View file @
fc1d8be3
...
...
@@ -7,6 +7,7 @@ from lmcache.integration.vllm.vllm_v1_adapter import (
LMCacheConnectorV1Impl
as
LMCacheConnectorLatestImpl
,
)
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.distributed.kv_transfer.kv_connector.v1.base
import
(
KVConnectorBase_V1
,
...
...
@@ -17,7 +18,6 @@ from vllm.logger import init_logger
from
vllm.v1.core.sched.output
import
SchedulerOutput
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.forward_context
import
ForwardContext
from
vllm.v1.core.kv_cache_manager
import
KVCacheBlocks
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
...
...
@@ -91,7 +91,7 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
self
,
layer_name
:
str
,
kv_layer
:
torch
.
Tensor
,
attn_metadata
:
"
AttentionMetadata
"
,
attn_metadata
:
AttentionMetadata
,
**
kwargs
:
Any
,
)
->
None
:
"""
...
...
vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py
View file @
fc1d8be3
...
...
@@ -29,6 +29,7 @@ from lmcache.v1.lookup_client.lmcache_async_lookup_client import (
from
lmcache.v1.offload_server.zmq_server
import
ZMQOffloadServer
from
lmcache.v1.plugin.plugin_launcher
import
PluginLauncher
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.distributed.kv_transfer.kv_connector.v1.base
import
(
KVConnectorBase_V1
,
...
...
@@ -50,7 +51,6 @@ from vllm.v1.core.sched.output import SchedulerOutput
from
vllm.version
import
__version__
as
VLLM_VERSION
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.forward_context
import
ForwardContext
from
vllm.multimodal.inputs
import
PlaceholderRange
from
vllm.v1.core.kv_cache_manager
import
KVCacheManager
...
...
@@ -915,7 +915,7 @@ class LMCacheConnectorV1Impl:
self
,
layer_name
:
str
,
kv_layer
:
torch
.
Tensor
,
attn_metadata
:
"
AttentionMetadata
"
,
attn_metadata
:
AttentionMetadata
,
**
kwargs
,
)
->
None
:
"""Start saving the a layer of KV cache from vLLM's paged buffer
...
...
vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py
View file @
fc1d8be3
...
...
@@ -10,6 +10,7 @@ import zmq
from
lmcache.integration.vllm.utils
import
mla_enabled
from
lmcache.utils
import
init_logger
as
lmcache_init_logger
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.distributed.kv_transfer.kv_connector.v1.base
import
(
KVConnectorBase_V1
,
...
...
@@ -26,7 +27,6 @@ from vllm.v1.outputs import KVConnectorOutput
from
vllm.v1.utils
import
ConstantList
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.distributed.kv_events
import
KVCacheEvent
from
vllm.distributed.kv_transfer.kv_connector.v1.metrics
import
(
...
...
@@ -490,7 +490,7 @@ class LMCacheMPConnector(KVConnectorBase_V1):
self
,
layer_name
:
str
,
kv_layer
:
torch
.
Tensor
,
attn_metadata
:
"
AttentionMetadata
"
,
attn_metadata
:
AttentionMetadata
,
**
kwargs
:
Any
,
)
->
None
:
"""
...
...
vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
View file @
fc1d8be3
...
...
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any
import
torch
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.config.kv_transfer
import
KVTransferConfig
from
vllm.distributed.kv_transfer.kv_connector.base
import
KVConnectorBaseType
...
...
@@ -27,7 +28,6 @@ from vllm.v1.core.sched.output import SchedulerOutput
from
vllm.v1.outputs
import
KVConnectorOutput
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.distributed.kv_events
import
KVCacheEvent
from
vllm.forward_context
import
ForwardContext
from
vllm.v1.core.kv_cache_manager
import
KVCacheBlocks
...
...
@@ -216,7 +216,7 @@ class MultiConnector(KVConnectorBase_V1):
self
,
layer_name
:
str
,
kv_layer
:
torch
.
Tensor
,
attn_metadata
:
"
AttentionMetadata
"
,
attn_metadata
:
AttentionMetadata
,
**
kwargs
,
)
->
None
:
for
c
in
self
.
_connectors
:
...
...
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
View file @
fc1d8be3
...
...
@@ -20,7 +20,7 @@ import torch
import
zmq
from
vllm
import
envs
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.attention.backends.abstract
import
AttentionBackend
,
AttentionMetadata
from
vllm.attention.backends.registry
import
AttentionBackendEnum
from
vllm.attention.selector
import
get_attn_backend
from
vllm.config
import
VllmConfig
...
...
@@ -51,7 +51,6 @@ from vllm.v1.core.sched.output import SchedulerOutput
from
vllm.v1.worker.block_table
import
BlockTable
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.v1.core.kv_cache_manager
import
KVCacheBlocks
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.request
import
Request
...
...
@@ -308,7 +307,7 @@ class NixlConnector(KVConnectorBase_V1):
self
,
layer_name
:
str
,
kv_layer
:
torch
.
Tensor
,
attn_metadata
:
"
AttentionMetadata
"
,
attn_metadata
:
AttentionMetadata
,
**
kwargs
,
)
->
None
:
"""NixlConnector does not save explicitly."""
...
...
vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
View file @
fc1d8be3
...
...
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Optional
import
regex
as
re
import
torch
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.distributed.kv_transfer.kv_connector.v1.base
import
(
KVConnectorBase_V1
,
...
...
@@ -22,7 +23,6 @@ from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from
vllm.v1.core.sched.output
import
SchedulerOutput
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.forward_context
import
ForwardContext
from
vllm.v1.core.kv_cache_manager
import
KVCacheBlocks
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
...
...
@@ -243,7 +243,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
self
,
layer_name
:
str
,
kv_layer
:
torch
.
Tensor
,
attn_metadata
:
"
AttentionMetadata
"
,
attn_metadata
:
AttentionMetadata
,
**
kwargs
:
Any
,
)
->
None
:
"""Start saving the KV cache of the layer from vLLM's paged buffer
...
...
vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
View file @
fc1d8be3
...
...
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Optional
import
safetensors
import
torch
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.distributed.kv_transfer.kv_connector.v1.base
import
(
KVConnectorBase_V1
,
...
...
@@ -19,7 +20,6 @@ from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from
vllm.v1.core.sched.output
import
SchedulerOutput
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.forward_context
import
ForwardContext
from
vllm.v1.core.kv_cache_manager
import
KVCacheBlocks
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
...
...
@@ -211,7 +211,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
self
,
layer_name
:
str
,
kv_layer
:
torch
.
Tensor
,
attn_metadata
:
"
AttentionMetadata
"
,
attn_metadata
:
AttentionMetadata
,
**
kwargs
:
Any
,
)
->
None
:
"""Start saving the KV cache of the layer from vLLM's paged buffer
...
...
vllm/forward_context.py
View file @
fc1d8be3
...
...
@@ -5,19 +5,17 @@ import time
from
collections
import
defaultdict
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Any
,
NamedTuple
from
typing
import
Any
,
NamedTuple
import
torch
import
vllm.envs
as
envs
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.config
import
CUDAGraphMode
,
ParallelConfig
,
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.v1.worker.dp_utils
import
coordinate_batch_across_dp
from
vllm.v1.worker.ubatch_utils
import
UBatchSlices
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionMetadata
logger
=
init_logger
(
__name__
)
track_batchsize
:
bool
=
envs
.
VLLM_LOG_BATCHSIZE_INTERVAL
>=
0
...
...
@@ -195,7 +193,7 @@ class ForwardContext:
for each microbatch.
Set dynamically for each forward pass
"""
attn_metadata
:
dict
[
str
,
"
AttentionMetadata
"
]
|
list
[
dict
[
str
,
"
AttentionMetadata
"
]]
attn_metadata
:
dict
[
str
,
AttentionMetadata
]
|
list
[
dict
[
str
,
AttentionMetadata
]]
# TODO: remove after making all virtual_engines share the same kv cache
virtual_engine
:
int
# set dynamically for each forward pass
# set dynamically for each forward pass
...
...
vllm/model_executor/layers/attention_layer_base.py
View file @
fc1d8be3
...
...
@@ -3,14 +3,11 @@
"""Base class for attention-like layers."""
from
abc
import
ABC
,
abstractmethod
from
typing
import
TYPE_CHECKING
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.config
import
VllmConfig
from
vllm.v1.kv_cache_interface
import
KVCacheSpec
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionBackend
class
AttentionLayerBase
(
ABC
):
"""
...
...
@@ -22,7 +19,7 @@ class AttentionLayerBase(ABC):
"""
@
abstractmethod
def
get_attn_backend
(
self
)
->
type
[
"
AttentionBackend
"
]:
def
get_attn_backend
(
self
)
->
type
[
AttentionBackend
]:
"""Get the attention backend class for this layer."""
pass
...
...
vllm/model_executor/layers/mamba/abstract.py
View file @
fc1d8be3
...
...
@@ -2,18 +2,15 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
abc
import
abstractmethod
from
collections.abc
import
Iterable
from
typing
import
TYPE_CHECKING
import
torch
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.attention.selector
import
get_mamba_attn_backend
from
vllm.config
import
VllmConfig
from
vllm.model_executor.layers.attention_layer_base
import
AttentionLayerBase
from
vllm.v1.kv_cache_interface
import
KVCacheSpec
,
MambaSpec
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionBackend
class
MambaBase
(
AttentionLayerBase
):
"""
...
...
@@ -66,6 +63,6 @@ class MambaBase(AttentionLayerBase):
),
)
def
get_attn_backend
(
self
)
->
type
[
"
AttentionBackend
"
]:
def
get_attn_backend
(
self
)
->
type
[
AttentionBackend
]:
"""Get the attention backend class for this Mamba layer."""
return
get_mamba_attn_backend
(
self
.
mamba_type
)
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
View file @
fc1d8be3
...
...
@@ -18,6 +18,7 @@ from compressed_tensors.quantization import (
from
compressed_tensors.transform
import
TransformConfig
import
vllm.envs
as
envs
from
vllm.attention.layer
import
Attention
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.linear
import
(
...
...
@@ -131,8 +132,6 @@ class CompressedTensorsConfig(QuantizationConfig):
layer
:
torch
.
nn
.
Module
,
prefix
:
str
,
)
->
Optional
[
"QuantizeMethodBase"
]:
from
vllm.attention.layer
import
Attention
# Avoid circular import
if
isinstance
(
layer
,
LinearBase
):
# collect schemes
quant_scheme
=
self
.
get_scheme
(
layer
=
layer
,
layer_name
=
prefix
)
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
fc1d8be3
...
...
@@ -14,6 +14,7 @@ import vllm.envs as envs
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm
import
_custom_ops
as
ops
from
vllm._aiter_ops
import
rocm_aiter_ops
from
vllm.attention.layer
import
Attention
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.batch_invariant
import
(
...
...
@@ -277,7 +278,6 @@ class Fp8Config(QuantizationConfig):
def
get_xpu_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
)
->
Optional
[
"QuantizeMethodBase"
]:
from
vllm.attention.layer
import
Attention
from
vllm.model_executor.layers.quantization.ipex_quant
import
(
XPUFp8LinearMethod
,
XPUFp8MoEMethod
,
...
...
@@ -307,8 +307,6 @@ class Fp8Config(QuantizationConfig):
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
)
->
Optional
[
"QuantizeMethodBase"
]:
from
vllm.attention.layer
import
Attention
# Avoid circular import
if
current_platform
.
is_xpu
():
return
self
.
get_xpu_quant_method
(
layer
,
prefix
)
if
isinstance
(
layer
,
LinearBase
):
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment