Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a608b4c6
Unverified
Commit
a608b4c6
authored
Jan 27, 2026
by
Matthew Bonanni
Committed by
GitHub
Jan 27, 2026
Browse files
[5/N][Attention] Finish eliminating `vllm/attention` folder (#32064)
Signed-off-by:
Matthew Bonanni
<
mbonanni@redhat.com
>
parent
1f3a2c29
Changes
151
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
415 additions
and
351 deletions
+415
-351
vllm/model_executor/layers/attention/attention.py
vllm/model_executor/layers/attention/attention.py
+42
-315
vllm/model_executor/layers/attention/chunked_local_attention.py
...odel_executor/layers/attention/chunked_local_attention.py
+1
-1
vllm/model_executor/layers/attention/cross_attention.py
vllm/model_executor/layers/attention/cross_attention.py
+1
-1
vllm/model_executor/layers/attention/encoder_only_attention.py
...model_executor/layers/attention/encoder_only_attention.py
+1
-1
vllm/model_executor/layers/attention/kv_transfer_utils.py
vllm/model_executor/layers/attention/kv_transfer_utils.py
+1
-1
vllm/model_executor/layers/attention/mla_attention.py
vllm/model_executor/layers/attention/mla_attention.py
+354
-17
vllm/model_executor/layers/attention/static_sink_attention.py
.../model_executor/layers/attention/static_sink_attention.py
+1
-1
vllm/model_executor/layers/mla.py
vllm/model_executor/layers/mla.py
+1
-1
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
...ers/quantization/compressed_tensors/compressed_tensors.py
+1
-1
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+1
-1
vllm/model_executor/layers/quantization/modelopt.py
vllm/model_executor/layers/quantization/modelopt.py
+1
-1
vllm/model_executor/layers/quantization/mxfp4.py
vllm/model_executor/layers/quantization/mxfp4.py
+1
-1
vllm/model_executor/layers/quantization/petit.py
vllm/model_executor/layers/quantization/petit.py
+1
-1
vllm/model_executor/layers/quantization/ptpc_fp8.py
vllm/model_executor/layers/quantization/ptpc_fp8.py
+1
-1
vllm/model_executor/layers/quantization/quark/quark.py
vllm/model_executor/layers/quantization/quark/quark.py
+1
-1
vllm/model_executor/model_loader/utils.py
vllm/model_executor/model_loader/utils.py
+1
-1
vllm/model_executor/models/afmoe.py
vllm/model_executor/models/afmoe.py
+1
-1
vllm/model_executor/models/aimv2.py
vllm/model_executor/models/aimv2.py
+1
-1
vllm/model_executor/models/apertus.py
vllm/model_executor/models/apertus.py
+2
-2
vllm/model_executor/models/arctic.py
vllm/model_executor/models/arctic.py
+1
-1
No files found.
vllm/attention/
layer
.py
→
vllm/
model_executor/layers/
attention/
attention
.py
View file @
a608b4c6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer."""
from
typing
import
cast
from
typing
import
TYPE_CHECKING
import
torch
import
torch.nn
as
nn
import
vllm.envs
as
envs
from
vllm.attention.utils.kv_sharing_utils
import
validate_kv_sharing_target
from
vllm.attention.utils.kv_transfer_utils
import
maybe_transfer_kv_layer
from
vllm.config
import
CacheConfig
,
get_current_vllm_config
from
vllm.config.vllm
import
VllmConfig
from
vllm.forward_context
import
ForwardContext
,
get_forward_context
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.attention.kv_transfer_utils
import
(
maybe_transfer_kv_layer
,
)
from
vllm.model_executor.layers.attention_layer_base
import
AttentionLayerBase
from
vllm.model_executor.layers.batch_invariant
import
vllm_is_batch_invariant
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
UnquantizedLinearMethod
,
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
...
...
@@ -33,20 +32,54 @@ from vllm.utils.torch_utils import (
from
vllm.v1.attention.backend
import
(
AttentionBackend
,
AttentionType
,
MLAAttentionImpl
,
)
from
vllm.v1.attention.backends.registry
import
AttentionBackendEnum
from
vllm.v1.attention.selector
import
get_attn_backend
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheSpec
,
MLAAttentionSpec
,
SlidingWindowSpec
,
)
if
TYPE_CHECKING
:
from
vllm.model_executor.layers.attention
import
MLAAttention
logger
=
init_logger
(
__name__
)
def
validate_kv_sharing_target
(
current_layer_name
,
target_layer_name
,
static_forward_context
):
error_msg
=
(
f
"Specified KV sharing target layer for
{
current_layer_name
}
"
f
"is not valid: target layer
{
target_layer_name
}
"
)
if
current_layer_name
==
target_layer_name
:
raise
ValueError
(
error_msg
+
"cannot be the same as the current layer."
)
if
target_layer_name
not
in
static_forward_context
:
from
vllm.model_executor.models.utils
import
extract_layer_index
# If target layer name is not in the static fwd context, it means either
# a) the target layer does not come BEFORE the current layer, or
# b) the target layer is not an Attention layer that exists in the model
current_layer_idx
=
extract_layer_index
(
current_layer_name
)
target_layer_idx
=
extract_layer_index
(
target_layer_name
)
if
current_layer_idx
<=
target_layer_idx
:
raise
ValueError
(
error_msg
+
"must come before the current layer."
)
else
:
raise
ValueError
(
error_msg
+
"is not a valid Attention layer in the model."
)
# Currently KV sharing is only supported between layers of the same type
target_layer_attn_type
=
static_forward_context
[
target_layer_name
].
attn_type
expected
=
static_forward_context
[
current_layer_name
].
attn_type
if
target_layer_attn_type
!=
expected
:
raise
ValueError
(
error_msg
+
f
"must be the same type as the current layer (
{
expected
}
)."
)
def
should_load_quant_weights
(
quant_method
:
QuantizeMethodBase
|
None
)
->
bool
:
"""Returns whether the quantization method should load quantized weights."""
return
quant_method
is
not
None
and
not
isinstance
(
...
...
@@ -493,236 +526,6 @@ class Attention(nn.Module, AttentionLayerBase):
)
class
MLAAttention
(
nn
.
Module
,
AttentionLayerBase
):
"""Multi-Head Latent Attention layer.
This class takes query, and compressed key/value tensors as input.
The class does the following:
1. Store the input key and value tensors in the KV cache.
2. Perform (multi-head/multi-query/grouped-query) attention.
3. Return the output tensor.
"""
def
__init__
(
self
,
num_heads
:
int
,
scale
:
float
,
qk_nope_head_dim
:
int
,
qk_rope_head_dim
:
int
,
v_head_dim
:
int
,
q_lora_rank
:
int
|
None
,
kv_lora_rank
:
int
,
kv_b_proj
:
ColumnParallelLinear
,
cache_config
:
CacheConfig
|
None
=
None
,
quant_config
:
QuantizationConfig
|
None
=
None
,
prefix
:
str
=
""
,
use_sparse
:
bool
=
False
,
indexer
:
object
|
None
=
None
,
**
extra_impl_args
,
):
super
().
__init__
()
self
.
num_heads
=
num_heads
self
.
scale
=
scale
self
.
qk_nope_head_dim
=
qk_nope_head_dim
self
.
qk_rope_head_dim
=
qk_rope_head_dim
self
.
v_head_dim
=
v_head_dim
self
.
q_lora_rank
=
q_lora_rank
self
.
kv_lora_rank
=
kv_lora_rank
self
.
head_size
=
kv_lora_rank
+
qk_rope_head_dim
self
.
layer_name
=
prefix
if
cache_config
is
not
None
:
kv_cache_dtype
=
cache_config
.
cache_dtype
block_size
=
cache_config
.
block_size
calculate_kv_scales
=
cache_config
.
calculate_kv_scales
else
:
kv_cache_dtype
=
"auto"
block_size
=
16
calculate_kv_scales
=
False
self
.
quant_config
=
quant_config
# Initialize KV cache quantization attributes
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
calculate_kv_scales
=
calculate_kv_scales
_init_kv_cache_quant
(
self
,
quant_config
,
prefix
)
dtype
=
torch
.
get_default_dtype
()
self
.
attn_backend
=
get_attn_backend
(
self
.
head_size
,
dtype
,
kv_cache_dtype
,
block_size
,
use_mla
=
True
,
use_sparse
=
use_sparse
,
)
if
(
cache_config
is
not
None
and
cache_config
.
enable_prefix_caching
and
vllm_is_batch_invariant
()
and
(
self
.
attn_backend
.
get_name
()
==
"TRITON_MLA"
or
self
.
attn_backend
.
get_name
()
==
"FLASHINFER"
)
):
logger
.
warning_once
(
"Disabling prefix caching for TRITON_MLA / FLASHINFER "
"with batch invariance, as it is not yet supported."
,
scope
=
"local"
,
)
cache_config
.
enable_prefix_caching
=
False
impl_cls
=
cast
(
type
[
MLAAttentionImpl
],
self
.
attn_backend
.
get_impl_cls
())
self
.
impl
=
impl_cls
(
num_heads
=
self
.
num_heads
,
head_size
=
self
.
head_size
,
scale
=
self
.
scale
,
num_kv_heads
=
1
,
alibi_slopes
=
None
,
sliding_window
=
None
,
kv_cache_dtype
=
self
.
kv_cache_dtype
,
logits_soft_cap
=
None
,
attn_type
=
AttentionType
.
DECODER
,
kv_sharing_target_layer_name
=
None
,
# MLA Args
q_lora_rank
=
self
.
q_lora_rank
,
kv_lora_rank
=
self
.
kv_lora_rank
,
qk_nope_head_dim
=
self
.
qk_nope_head_dim
,
qk_rope_head_dim
=
self
.
qk_rope_head_dim
,
qk_head_dim
=
self
.
qk_nope_head_dim
+
self
.
qk_rope_head_dim
,
v_head_dim
=
self
.
v_head_dim
,
kv_b_proj
=
kv_b_proj
,
indexer
=
indexer
,
**
extra_impl_args
,
)
self
.
use_direct_call
=
not
current_platform
.
opaque_attention_op
()
compilation_config
=
get_current_vllm_config
().
compilation_config
if
prefix
in
compilation_config
.
static_forward_context
:
raise
ValueError
(
f
"Duplicate layer name:
{
prefix
}
"
)
compilation_config
.
static_forward_context
[
prefix
]
=
self
self
.
kv_cache
=
[
torch
.
tensor
([])
for
_
in
range
(
get_current_vllm_config
().
parallel_config
.
pipeline_parallel_size
)
]
self
.
use_sparse
=
use_sparse
# Initialize q/k/v range constants.
self
.
q_range
=
torch
.
tensor
(
envs
.
Q_SCALE_CONSTANT
,
dtype
=
torch
.
float32
)
self
.
k_range
=
torch
.
tensor
(
envs
.
K_SCALE_CONSTANT
,
dtype
=
torch
.
float32
)
self
.
v_range
=
torch
.
tensor
(
envs
.
V_SCALE_CONSTANT
,
dtype
=
torch
.
float32
)
def
forward
(
self
,
q
:
torch
.
Tensor
,
kv_c_normed
:
torch
.
Tensor
,
k_pe
:
torch
.
Tensor
,
output_shape
:
torch
.
Size
|
None
=
None
,
)
->
torch
.
Tensor
:
if
self
.
calculate_kv_scales
:
torch
.
ops
.
vllm
.
maybe_calc_kv_scales
(
q
,
kv_c_normed
,
k_pe
,
self
.
layer_name
)
if
self
.
use_direct_call
:
forward_context
:
ForwardContext
=
get_forward_context
()
attn_metadata
=
forward_context
.
attn_metadata
if
isinstance
(
attn_metadata
,
dict
):
attn_metadata
=
attn_metadata
[
self
.
layer_name
]
self_kv_cache
=
self
.
kv_cache
[
forward_context
.
virtual_engine
]
if
self
.
attn_backend
.
accept_output_buffer
:
output
=
torch
.
empty
(
output_shape
,
dtype
=
q
.
dtype
,
device
=
q
.
device
)
self
.
impl
.
forward
(
self
,
q
,
kv_c_normed
,
k_pe
,
self_kv_cache
,
attn_metadata
,
output
=
output
,
)
return
output
else
:
return
self
.
impl
.
forward
(
self
,
q
,
kv_c_normed
,
k_pe
,
self_kv_cache
,
attn_metadata
)
else
:
if
self
.
attn_backend
.
accept_output_buffer
:
output
=
torch
.
empty
(
output_shape
,
dtype
=
q
.
dtype
,
device
=
q
.
device
)
torch
.
ops
.
vllm
.
unified_mla_attention_with_output
(
q
,
kv_c_normed
,
k_pe
,
output
,
self
.
layer_name
,
)
return
output
else
:
return
torch
.
ops
.
vllm
.
unified_mla_attention
(
q
,
kv_c_normed
,
k_pe
,
self
.
layer_name
,
)
def
process_weights_after_loading
(
self
,
act_dtype
:
torch
.
dtype
):
if
hasattr
(
self
.
impl
,
"process_weights_after_loading"
):
self
.
impl
.
process_weights_after_loading
(
act_dtype
)
# If we should not load quant weights, we initialize the scales to 1.0
# as the default value. See [Note: Register q/k/v/prob scales in state dict]
# for more details.
quant_method
=
(
self
.
quant_config
.
get_quant_method
(
self
,
prefix
=
self
.
layer_name
)
if
self
.
quant_config
else
None
)
if
not
should_load_quant_weights
(
quant_method
):
set_default_quant_scales
(
self
,
register_buffer
=
False
)
def
calc_kv_scales
(
self
,
q
:
torch
.
Tensor
,
kv_c_normed
:
torch
.
Tensor
,
k_pe
:
torch
.
Tensor
)
->
None
:
"""Optional scale calculation for MLA inputs.
Mirrors Attention.calc_kv_scales. Not all MLA backends require this
"""
# Use safe defaults if ranges are not present
q_range
=
getattr
(
self
,
"q_range"
,
torch
.
tensor
(
1.0
))
k_range
=
getattr
(
self
,
"k_range"
,
torch
.
tensor
(
1.0
))
v_range
=
getattr
(
self
,
"v_range"
,
torch
.
tensor
(
1.0
))
self
.
_q_scale
.
copy_
(
torch
.
abs
(
q
).
max
()
/
q_range
)
# kv_c_normed is the compressed KV representation; use it for k/v
kv_abs_max
=
torch
.
abs
(
kv_c_normed
).
max
()
self
.
_k_scale
.
copy_
(
kv_abs_max
/
k_range
)
self
.
_v_scale
.
copy_
(
kv_abs_max
/
v_range
)
self
.
_q_scale_float
=
self
.
_q_scale
.
item
()
self
.
_k_scale_float
=
self
.
_k_scale
.
item
()
self
.
_v_scale_float
=
self
.
_v_scale
.
item
()
self
.
calculate_kv_scales
=
False
def
get_attn_backend
(
self
)
->
type
[
AttentionBackend
]:
return
self
.
attn_backend
def
get_kv_cache_spec
(
self
,
vllm_config
:
VllmConfig
)
->
KVCacheSpec
:
kv_cache_dtype
=
kv_cache_dtype_str_to_dtype
(
self
.
kv_cache_dtype
,
vllm_config
.
model_config
)
return
MLAAttentionSpec
(
block_size
=
vllm_config
.
cache_config
.
block_size
,
num_kv_heads
=
1
,
head_size
=
self
.
head_size
,
dtype
=
kv_cache_dtype
,
cache_dtype_str
=
vllm_config
.
cache_config
.
cache_dtype
,
)
def
maybe_calc_kv_scales
(
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
...
...
@@ -759,7 +562,7 @@ direct_register_custom_op(
def
get_attention_context
(
layer_name
:
str
,
)
->
tuple
[
dict
|
object
|
None
,
Attention
|
MLAAttention
,
torch
.
Tensor
]:
)
->
tuple
[
dict
|
object
|
None
,
"
Attention | MLAAttention
"
,
torch
.
Tensor
]:
"""Extract attention context for a given layer.
This helper function extracts the attention metadata, attention layer
...
...
@@ -782,7 +585,7 @@ def get_attention_context(
attn_metadata
=
forward_context
.
attn_metadata
if
isinstance
(
attn_metadata
,
dict
):
attn_metadata
=
attn_metadata
[
layer_name
]
attn_layer
:
Attention
|
MLAAttention
=
forward_context
.
no_compile_layers
[
layer_name
]
attn_layer
=
forward_context
.
no_compile_layers
[
layer_name
]
kv_cache
=
attn_layer
.
kv_cache
[
forward_context
.
virtual_engine
]
return
attn_metadata
,
attn_layer
,
kv_cache
...
...
@@ -914,79 +717,3 @@ direct_register_custom_op(
mutates_args
=
[
"output"
,
"output_block_scale"
],
fake_impl
=
unified_attention_with_output_fake
,
)
@
maybe_transfer_kv_layer
def
unified_mla_attention
(
q
:
torch
.
Tensor
,
kv_c_normed
:
torch
.
Tensor
,
k_pe
:
torch
.
Tensor
,
layer_name
:
str
,
)
->
torch
.
Tensor
:
attn_metadata
,
self
,
kv_cache
=
get_attention_context
(
layer_name
)
output
=
self
.
impl
.
forward
(
self
,
q
,
kv_c_normed
,
k_pe
,
kv_cache
,
attn_metadata
)
return
output
def
unified_mla_attention_fake
(
q
:
torch
.
Tensor
,
kv_c_normed
:
torch
.
Tensor
,
k_pe
:
torch
.
Tensor
,
layer_name
:
str
,
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
q
).
contiguous
()
direct_register_custom_op
(
op_name
=
"unified_mla_attention"
,
op_func
=
unified_mla_attention
,
mutates_args
=
[],
fake_impl
=
unified_mla_attention_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
)
@
maybe_transfer_kv_layer
def
unified_mla_attention_with_output
(
q
:
torch
.
Tensor
,
kv_c_normed
:
torch
.
Tensor
,
k_pe
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
layer_name
:
str
,
output_scale
:
torch
.
Tensor
|
None
=
None
,
output_block_scale
:
torch
.
Tensor
|
None
=
None
,
)
->
None
:
attn_metadata
,
self
,
kv_cache
=
get_attention_context
(
layer_name
)
self
.
impl
.
forward
(
self
,
q
,
kv_c_normed
,
k_pe
,
kv_cache
,
attn_metadata
,
output
=
output
,
output_scale
=
output_scale
,
output_block_scale
=
output_block_scale
,
)
def
unified_mla_attention_with_output_fake
(
q
:
torch
.
Tensor
,
kv_c_normed
:
torch
.
Tensor
,
k_pe
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
layer_name
:
str
,
output_scale
:
torch
.
Tensor
|
None
=
None
,
output_block_scale
:
torch
.
Tensor
|
None
=
None
,
)
->
None
:
return
direct_register_custom_op
(
op_name
=
"unified_mla_attention_with_output"
,
op_func
=
unified_mla_attention_with_output
,
mutates_args
=
[
"output"
,
"output_block_scale"
],
fake_impl
=
unified_mla_attention_with_output_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
)
vllm/model_executor/layers/attention/chunked_local_attention.py
View file @
a608b4c6
...
...
@@ -4,9 +4,9 @@ import functools
import
torch
from
vllm.attention.layer
import
Attention
from
vllm.config
import
CacheConfig
from
vllm.config.vllm
import
VllmConfig
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.v1.attention.backend
import
(
AttentionBackend
,
...
...
vllm/model_executor/layers/attention/cross_attention.py
View file @
a608b4c6
...
...
@@ -6,9 +6,9 @@ from copy import copy
import
numpy
as
np
import
torch
from
vllm.attention.layer
import
Attention
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.utils.math_utils
import
cdiv
from
vllm.v1.attention.backend
import
(
AttentionBackend
,
...
...
vllm/model_executor/layers/attention/encoder_only_attention.py
View file @
a608b4c6
...
...
@@ -5,9 +5,9 @@ from copy import copy
import
torch
from
vllm.attention.layer
import
Attention
from
vllm.config
import
CacheConfig
from
vllm.config.vllm
import
VllmConfig
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.v1.attention.backend
import
(
AttentionBackend
,
AttentionMetadata
,
...
...
vllm/attention/
utils/
kv_transfer_utils.py
→
vllm/
model_executor/layers/
attention/kv_transfer_utils.py
View file @
a608b4c6
...
...
@@ -19,7 +19,7 @@ def maybe_transfer_kv_layer(func: Callable) -> Callable:
On exit: saves the KV layer to the connector.
"""
# Import at runtime to avoid circular dependency
from
vllm.attention.
layer
import
get_attention_context
from
vllm.
model_executor.layers.
attention.
attention
import
get_attention_context
# Inspect the signature ONCE when the decorator is applied.
sig
=
inspect
.
signature
(
func
)
...
...
vllm/model_executor/layers/attention/mla_attention.py
100755 → 100644
View file @
a608b4c6
...
...
@@ -191,24 +191,38 @@ import functools
from
abc
import
abstractmethod
from
dataclasses
import
dataclass
,
field
from
enum
import
Enum
from
typing
import
ClassVar
,
Generic
,
TypeVar
from
typing
import
TYPE_CHECKING
,
ClassVar
,
Generic
,
TypeVar
,
cast
if
TYPE_CHECKING
:
from
flashinfer
import
BatchPrefillWithRaggedKVCacheWrapper
import
torch
import
torch.nn
as
nn
from
tqdm
import
tqdm
import
vllm.envs
as
envs
from
vllm
import
_custom_ops
as
ops
from
vllm
import
envs
from
vllm._aiter_ops
import
rocm_aiter_ops
from
vllm.config
import
ModelConfig
,
VllmConfig
,
get_current_vllm_config
from
vllm.config
import
CacheConfig
,
ModelConfig
,
VllmConfig
,
get_current_vllm_config
from
vllm.distributed.parallel_state
import
get_dcp_group
,
is_global_first_rank
from
vllm.forward_context
import
ForwardContext
,
get_forward_context
from
vllm.logger
import
init_logger
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.layers.batch_invariant
import
(
vllm_is_batch_invariant
,
from
vllm.model_executor.layers.attention.attention
import
(
_init_kv_cache_quant
,
get_attention_context
,
set_default_quant_scales
,
should_load_quant_weights
,
)
from
vllm.model_executor.layers.attention.kv_transfer_utils
import
(
maybe_transfer_kv_layer
,
)
from
vllm.model_executor.layers.attention_layer_base
import
AttentionLayerBase
from
vllm.model_executor.layers.batch_invariant
import
vllm_is_batch_invariant
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization.input_quant_fp8
import
QuantFP8
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
GroupShape
,
...
...
@@ -217,11 +231,16 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from
vllm.platforms
import
current_platform
from
vllm.utils.flashinfer
import
has_nvidia_artifactory
from
vllm.utils.math_utils
import
cdiv
,
round_down
from
vllm.utils.torch_utils
import
(
direct_register_custom_op
,
kv_cache_dtype_str_to_dtype
,
)
from
vllm.v1.attention.backend
import
(
AttentionBackend
,
AttentionLayer
,
AttentionMetadata
,
AttentionMetadataBuilder
,
AttentionType
,
CommonAttentionMetadata
,
MLAAttentionImpl
,
)
...
...
@@ -234,7 +253,320 @@ from vllm.v1.attention.backends.utils import (
)
from
vllm.v1.attention.ops.common
import
cp_lse_ag_out_rs
from
vllm.v1.attention.ops.merge_attn_states
import
merge_attn_states
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.attention.selector
import
get_attn_backend
from
vllm.v1.kv_cache_interface
import
(
AttentionSpec
,
KVCacheSpec
,
MLAAttentionSpec
,
)
logger
=
init_logger
(
__name__
)
class
MLAAttention
(
nn
.
Module
,
AttentionLayerBase
):
"""Multi-Head Latent Attention layer.
This class takes query, and compressed key/value tensors as input.
The class does the following:
1. Store the input key and value tensors in the KV cache.
2. Perform (multi-head/multi-query/grouped-query) attention.
3. Return the output tensor.
"""
def
__init__
(
self
,
num_heads
:
int
,
scale
:
float
,
qk_nope_head_dim
:
int
,
qk_rope_head_dim
:
int
,
v_head_dim
:
int
,
q_lora_rank
:
int
|
None
,
kv_lora_rank
:
int
,
kv_b_proj
:
ColumnParallelLinear
,
cache_config
:
CacheConfig
|
None
=
None
,
quant_config
:
QuantizationConfig
|
None
=
None
,
prefix
:
str
=
""
,
use_sparse
:
bool
=
False
,
indexer
:
object
|
None
=
None
,
**
extra_impl_args
,
):
super
().
__init__
()
self
.
num_heads
=
num_heads
self
.
scale
=
scale
self
.
qk_nope_head_dim
=
qk_nope_head_dim
self
.
qk_rope_head_dim
=
qk_rope_head_dim
self
.
v_head_dim
=
v_head_dim
self
.
q_lora_rank
=
q_lora_rank
self
.
kv_lora_rank
=
kv_lora_rank
self
.
head_size
=
kv_lora_rank
+
qk_rope_head_dim
self
.
layer_name
=
prefix
if
cache_config
is
not
None
:
kv_cache_dtype
=
cache_config
.
cache_dtype
block_size
=
cache_config
.
block_size
calculate_kv_scales
=
cache_config
.
calculate_kv_scales
else
:
kv_cache_dtype
=
"auto"
block_size
=
16
calculate_kv_scales
=
False
self
.
quant_config
=
quant_config
# Initialize KV cache quantization attributes
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
calculate_kv_scales
=
calculate_kv_scales
_init_kv_cache_quant
(
self
,
quant_config
,
prefix
)
dtype
=
torch
.
get_default_dtype
()
self
.
attn_backend
=
get_attn_backend
(
self
.
head_size
,
dtype
,
kv_cache_dtype
,
block_size
,
use_mla
=
True
,
use_sparse
=
use_sparse
,
)
if
(
cache_config
is
not
None
and
cache_config
.
enable_prefix_caching
and
vllm_is_batch_invariant
()
and
(
self
.
attn_backend
.
get_name
()
==
"TRITON_MLA"
or
self
.
attn_backend
.
get_name
()
==
"FLASHINFER"
)
):
logger
.
warning_once
(
"Disabling prefix caching for TRITON_MLA / FLASHINFER "
"with batch invariance, as it is not yet supported."
,
scope
=
"local"
,
)
cache_config
.
enable_prefix_caching
=
False
impl_cls
=
cast
(
type
[
MLAAttentionImpl
],
self
.
attn_backend
.
get_impl_cls
())
self
.
impl
=
impl_cls
(
num_heads
=
self
.
num_heads
,
head_size
=
self
.
head_size
,
scale
=
self
.
scale
,
num_kv_heads
=
1
,
alibi_slopes
=
None
,
sliding_window
=
None
,
kv_cache_dtype
=
self
.
kv_cache_dtype
,
logits_soft_cap
=
None
,
attn_type
=
AttentionType
.
DECODER
,
kv_sharing_target_layer_name
=
None
,
# MLA Args
q_lora_rank
=
self
.
q_lora_rank
,
kv_lora_rank
=
self
.
kv_lora_rank
,
qk_nope_head_dim
=
self
.
qk_nope_head_dim
,
qk_rope_head_dim
=
self
.
qk_rope_head_dim
,
qk_head_dim
=
self
.
qk_nope_head_dim
+
self
.
qk_rope_head_dim
,
v_head_dim
=
self
.
v_head_dim
,
kv_b_proj
=
kv_b_proj
,
indexer
=
indexer
,
**
extra_impl_args
,
)
self
.
use_direct_call
=
not
current_platform
.
opaque_attention_op
()
compilation_config
=
get_current_vllm_config
().
compilation_config
if
prefix
in
compilation_config
.
static_forward_context
:
raise
ValueError
(
f
"Duplicate layer name:
{
prefix
}
"
)
compilation_config
.
static_forward_context
[
prefix
]
=
self
self
.
kv_cache
=
[
torch
.
tensor
([])
for
_
in
range
(
get_current_vllm_config
().
parallel_config
.
pipeline_parallel_size
)
]
self
.
use_sparse
=
use_sparse
# Initialize q/k/v range constants.
self
.
q_range
=
torch
.
tensor
(
envs
.
Q_SCALE_CONSTANT
,
dtype
=
torch
.
float32
)
self
.
k_range
=
torch
.
tensor
(
envs
.
K_SCALE_CONSTANT
,
dtype
=
torch
.
float32
)
self
.
v_range
=
torch
.
tensor
(
envs
.
V_SCALE_CONSTANT
,
dtype
=
torch
.
float32
)
def
forward
(
self
,
q
:
torch
.
Tensor
,
kv_c_normed
:
torch
.
Tensor
,
k_pe
:
torch
.
Tensor
,
output_shape
:
torch
.
Size
|
None
=
None
,
)
->
torch
.
Tensor
:
if
self
.
calculate_kv_scales
:
torch
.
ops
.
vllm
.
maybe_calc_kv_scales
(
q
,
kv_c_normed
,
k_pe
,
self
.
layer_name
)
if
self
.
use_direct_call
:
forward_context
:
ForwardContext
=
get_forward_context
()
attn_metadata
=
forward_context
.
attn_metadata
if
isinstance
(
attn_metadata
,
dict
):
attn_metadata
=
attn_metadata
[
self
.
layer_name
]
self_kv_cache
=
self
.
kv_cache
[
forward_context
.
virtual_engine
]
if
self
.
attn_backend
.
accept_output_buffer
:
output
=
torch
.
empty
(
output_shape
,
dtype
=
q
.
dtype
,
device
=
q
.
device
)
self
.
impl
.
forward
(
self
,
q
,
kv_c_normed
,
k_pe
,
self_kv_cache
,
attn_metadata
,
output
=
output
,
)
return
output
else
:
return
self
.
impl
.
forward
(
self
,
q
,
kv_c_normed
,
k_pe
,
self_kv_cache
,
attn_metadata
)
else
:
if
self
.
attn_backend
.
accept_output_buffer
:
output
=
torch
.
empty
(
output_shape
,
dtype
=
q
.
dtype
,
device
=
q
.
device
)
torch
.
ops
.
vllm
.
unified_mla_attention_with_output
(
q
,
kv_c_normed
,
k_pe
,
output
,
self
.
layer_name
,
)
return
output
else
:
return
torch
.
ops
.
vllm
.
unified_mla_attention
(
q
,
kv_c_normed
,
k_pe
,
self
.
layer_name
,
)
def
process_weights_after_loading
(
self
,
act_dtype
:
torch
.
dtype
):
if
hasattr
(
self
.
impl
,
"process_weights_after_loading"
):
self
.
impl
.
process_weights_after_loading
(
act_dtype
)
# If we should not load quant weights, we initialize the scales to 1.0
# as the default value. See [Note: Register q/k/v/prob scales in state dict]
# for more details.
quant_method
=
(
self
.
quant_config
.
get_quant_method
(
self
,
prefix
=
self
.
layer_name
)
if
self
.
quant_config
else
None
)
if
not
should_load_quant_weights
(
quant_method
):
set_default_quant_scales
(
self
,
register_buffer
=
False
)
def
calc_kv_scales
(
self
,
q
:
torch
.
Tensor
,
kv_c_normed
:
torch
.
Tensor
,
k_pe
:
torch
.
Tensor
)
->
None
:
"""Optional scale calculation for MLA inputs.
Mirrors Attention.calc_kv_scales. Not all MLA backends require this
"""
# Use safe defaults if ranges are not present
q_range
=
getattr
(
self
,
"q_range"
,
torch
.
tensor
(
1.0
))
k_range
=
getattr
(
self
,
"k_range"
,
torch
.
tensor
(
1.0
))
v_range
=
getattr
(
self
,
"v_range"
,
torch
.
tensor
(
1.0
))
self
.
_q_scale
.
copy_
(
torch
.
abs
(
q
).
max
()
/
q_range
)
# kv_c_normed is the compressed KV representation; use it for k/v
kv_abs_max
=
torch
.
abs
(
kv_c_normed
).
max
()
self
.
_k_scale
.
copy_
(
kv_abs_max
/
k_range
)
self
.
_v_scale
.
copy_
(
kv_abs_max
/
v_range
)
self
.
_q_scale_float
=
self
.
_q_scale
.
item
()
self
.
_k_scale_float
=
self
.
_k_scale
.
item
()
self
.
_v_scale_float
=
self
.
_v_scale
.
item
()
self
.
calculate_kv_scales
=
False
def
get_attn_backend
(
self
)
->
type
[
AttentionBackend
]:
return
self
.
attn_backend
def
get_kv_cache_spec
(
self
,
vllm_config
:
VllmConfig
)
->
KVCacheSpec
:
kv_cache_dtype
=
kv_cache_dtype_str_to_dtype
(
self
.
kv_cache_dtype
,
vllm_config
.
model_config
)
return
MLAAttentionSpec
(
block_size
=
vllm_config
.
cache_config
.
block_size
,
num_kv_heads
=
1
,
head_size
=
self
.
head_size
,
dtype
=
kv_cache_dtype
,
cache_dtype_str
=
vllm_config
.
cache_config
.
cache_dtype
,
)
@
maybe_transfer_kv_layer
def
unified_mla_attention
(
q
:
torch
.
Tensor
,
kv_c_normed
:
torch
.
Tensor
,
k_pe
:
torch
.
Tensor
,
layer_name
:
str
,
)
->
torch
.
Tensor
:
attn_metadata
,
self
,
kv_cache
=
get_attention_context
(
layer_name
)
output
=
self
.
impl
.
forward
(
self
,
q
,
kv_c_normed
,
k_pe
,
kv_cache
,
attn_metadata
)
return
output
def
unified_mla_attention_fake
(
q
:
torch
.
Tensor
,
kv_c_normed
:
torch
.
Tensor
,
k_pe
:
torch
.
Tensor
,
layer_name
:
str
,
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
q
).
contiguous
()
direct_register_custom_op
(
op_name
=
"unified_mla_attention"
,
op_func
=
unified_mla_attention
,
mutates_args
=
[],
fake_impl
=
unified_mla_attention_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
)
@
maybe_transfer_kv_layer
def
unified_mla_attention_with_output
(
q
:
torch
.
Tensor
,
kv_c_normed
:
torch
.
Tensor
,
k_pe
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
layer_name
:
str
,
output_scale
:
torch
.
Tensor
|
None
=
None
,
output_block_scale
:
torch
.
Tensor
|
None
=
None
,
)
->
None
:
attn_metadata
,
self
,
kv_cache
=
get_attention_context
(
layer_name
)
self
.
impl
.
forward
(
self
,
q
,
kv_c_normed
,
k_pe
,
kv_cache
,
attn_metadata
,
output
=
output
,
output_scale
=
output_scale
,
output_block_scale
=
output_block_scale
,
)
def
unified_mla_attention_with_output_fake
(
q
:
torch
.
Tensor
,
kv_c_normed
:
torch
.
Tensor
,
k_pe
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
layer_name
:
str
,
output_scale
:
torch
.
Tensor
|
None
=
None
,
output_block_scale
:
torch
.
Tensor
|
None
=
None
,
)
->
None
:
return
direct_register_custom_op
(
op_name
=
"unified_mla_attention_with_output"
,
op_func
=
unified_mla_attention_with_output
,
mutates_args
=
[
"output"
,
"output_block_scale"
],
fake_impl
=
unified_mla_attention_with_output_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
)
class
QueryLenSupport
(
Enum
):
...
...
@@ -266,15 +598,12 @@ except ImportError:
from
flash_attn
import
flash_attn_varlen_func
# type: ignore[no-redef]
is_vllm_fa
=
False
try
:
from
flashinfer
import
BatchPrefillWithRaggedKVCacheWrapper
from
flashinfer.prefill
import
cudnn_batch_prefill_with_kv_cache
# noqa: F401
flashinfer_available
=
Tru
e
except
ImportError
:
BatchPrefillWithRaggedKVCacheWrapper
=
object
@
functools
.
cach
e
def
flashinfer_available
()
->
bool
:
import
importlib.util
flashinfer_available
=
Fals
e
return
importlib
.
util
.
find_spec
(
"flashinfer"
)
is
not
Non
e
def
dynamic_per_batched_tensor_quant
(
...
...
@@ -398,8 +727,8 @@ class MLACommonPrefillMetadata:
@
dataclass
class
FlashInferPrefillMetadata
(
MLACommonPrefillMetadata
):
prefill_main
:
BatchPrefillWithRaggedKVCacheWrapper
|
None
=
None
prefill_chunks
:
list
[
BatchPrefillWithRaggedKVCacheWrapper
]
=
field
(
prefill_main
:
"
BatchPrefillWithRaggedKVCacheWrapper | None
"
=
None
prefill_chunks
:
"
list[BatchPrefillWithRaggedKVCacheWrapper]
"
=
field
(
default_factory
=
list
)
...
...
@@ -495,7 +824,7 @@ def use_flashinfer_prefill() -> bool:
vllm_config
=
get_current_vllm_config
()
if
not
(
not
vllm_config
.
attention_config
.
disable_flashinfer_prefill
and
flashinfer_available
and
flashinfer_available
()
and
not
vllm_config
.
attention_config
.
use_cudnn_prefill
and
current_platform
.
is_device_capability_family
(
100
)
):
...
...
@@ -509,7 +838,7 @@ def use_cudnn_prefill() -> bool:
vllm_config
=
get_current_vllm_config
()
return
(
flashinfer_available
flashinfer_available
()
and
vllm_config
.
attention_config
.
use_cudnn_prefill
and
current_platform
.
is_device_capability_family
(
100
)
and
has_nvidia_artifactory
()
...
...
@@ -731,6 +1060,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
has_context
=
True
if
self
.
_fi_prefill_main
is
None
:
from
flashinfer
import
BatchPrefillWithRaggedKVCacheWrapper
self
.
_fi_prefill_main
=
BatchPrefillWithRaggedKVCacheWrapper
(
self
.
_workspace_buffer
,
"NHD"
,
backend
=
"cutlass"
)
...
...
@@ -739,6 +1070,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
num_chunks
=
chunked_context
.
cu_seq_lens
.
shape
[
0
]
# Allocate more prefill chunk wrappers if needed
if
len
(
self
.
_fi_prefill_chunks
)
<
num_chunks
:
from
flashinfer
import
BatchPrefillWithRaggedKVCacheWrapper
for
_
in
range
(
len
(
self
.
_fi_prefill_chunks
),
num_chunks
):
self
.
_fi_prefill_chunks
.
append
(
BatchPrefillWithRaggedKVCacheWrapper
(
...
...
@@ -1513,6 +1846,8 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
):
assert
isinstance
(
prefill
,
CudnnPrefillMetadata
)
assert
prefill
.
query_seq_lens
is
not
None
from
flashinfer.prefill
import
cudnn_batch_prefill_with_kv_cache
output
,
lse
=
cudnn_batch_prefill_with_kv_cache
(
q
=
q
,
k_cache
=
k
,
...
...
@@ -1572,6 +1907,8 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
assert
prefill
.
chunked_context
is
not
None
assert
prefill
.
chunked_context
.
seq_lens
[
chunk_idx
]
is
not
None
assert
prefill
.
query_seq_lens
is
not
None
from
flashinfer.prefill
import
cudnn_batch_prefill_with_kv_cache
return
cudnn_batch_prefill_with_kv_cache
(
q
=
q
,
k_cache
=
k
,
...
...
vllm/model_executor/layers/attention/static_sink_attention.py
View file @
a608b4c6
...
...
@@ -4,11 +4,11 @@ import functools
import
torch
from
vllm.attention.layer
import
Attention
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.forward_context
import
ForwardContext
,
get_forward_context
from
vllm.logger
import
init_logger
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.utils.math_utils
import
cdiv
from
vllm.utils.torch_utils
import
direct_register_custom_op
from
vllm.v1.attention.backend
import
(
...
...
vllm/model_executor/layers/mla.py
View file @
a608b4c6
...
...
@@ -4,9 +4,9 @@ from dataclasses import dataclass
import
torch
from
vllm.attention.layer
import
MLAAttention
from
vllm.config
import
CacheConfig
from
vllm.model_executor.custom_op
import
PluggableLayer
from
vllm.model_executor.layers.attention
import
MLAAttention
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
View file @
a608b4c6
...
...
@@ -19,12 +19,12 @@ from compressed_tensors.quantization import (
from
compressed_tensors.transform
import
TransformConfig
import
vllm.envs
as
envs
from
vllm.attention.layer
import
Attention
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
)
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
a608b4c6
...
...
@@ -11,9 +11,9 @@ import vllm.envs as envs
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm
import
_custom_ops
as
ops
from
vllm._aiter_ops
import
rocm_aiter_ops
from
vllm.attention.layer
import
Attention
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.layers.batch_invariant
import
(
vllm_is_batch_invariant
,
)
...
...
vllm/model_executor/layers/quantization/modelopt.py
View file @
a608b4c6
...
...
@@ -11,8 +11,8 @@ from torch.nn.parameter import Parameter
import
vllm.envs
as
envs
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm._custom_ops
import
cutlass_scaled_fp4_mm
,
scaled_fp4_quant
from
vllm.attention.layer
import
Attention
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEConfig
,
FusedMoEQuantConfig
,
...
...
vllm/model_executor/layers/quantization/mxfp4.py
View file @
a608b4c6
...
...
@@ -7,9 +7,9 @@ import torch
from
torch.nn.parameter
import
Parameter
from
vllm
import
envs
from
vllm.attention.layer
import
Attention
from
vllm.config
import
get_current_vllm_config
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.layers.fused_moe
import
(
FusedMoE
,
FusedMoEConfig
,
...
...
vllm/model_executor/layers/quantization/petit.py
View file @
a608b4c6
...
...
@@ -8,8 +8,8 @@ import regex as re
import
torch
from
torch.nn.parameter
import
Parameter
from
vllm.attention.layer
import
Attention
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
,
...
...
vllm/model_executor/layers/quantization/ptpc_fp8.py
View file @
a608b4c6
...
...
@@ -7,8 +7,8 @@ import torch
from
torch.nn.parameter
import
Parameter
from
vllm
import
_custom_ops
as
ops
from
vllm.attention.layer
import
Attention
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.layers.linear
import
LinearBase
,
UnquantizedLinearMethod
from
vllm.model_executor.layers.quantization
import
QuantizationMethods
from
vllm.model_executor.layers.quantization.base_config
import
QuantizeMethodBase
...
...
vllm/model_executor/layers/quantization/quark/quark.py
View file @
a608b4c6
...
...
@@ -6,8 +6,8 @@ from typing import TYPE_CHECKING, Any, Optional, cast
import
torch
from
vllm.attention.layer
import
Attention
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
...
...
vllm/model_executor/model_loader/utils.py
View file @
a608b4c6
...
...
@@ -11,9 +11,9 @@ import torch
from
torch
import
nn
from
typing_extensions
import
assert_never
from
vllm.attention.layer
import
Attention
,
MLAAttention
from
vllm.config
import
ModelConfig
,
VllmConfig
,
set_current_vllm_config
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.attention
import
Attention
,
MLAAttention
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
,
...
...
vllm/model_executor/models/afmoe.py
View file @
a608b4c6
...
...
@@ -9,7 +9,6 @@ from itertools import islice
import
torch
from
torch
import
nn
from
vllm.attention.layer
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
,
get_current_vllm_config
from
vllm.distributed
import
(
...
...
@@ -18,6 +17,7 @@ from vllm.distributed import (
get_tensor_model_parallel_world_size
,
)
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.layers.fused_moe.shared_fused_moe
import
SharedFusedMoE
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
...
...
vllm/model_executor/models/aimv2.py
View file @
a608b4c6
...
...
@@ -11,7 +11,7 @@ import torch.nn as nn
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed.utils
import
divide
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.attention
.mm_encoder_attention
import
MMEncoderAttention
from
vllm.model_executor.layers.attention
import
MMEncoderAttention
from
vllm.model_executor.layers.conv
import
Conv2dLayer
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
...
...
vllm/model_executor/models/apertus.py
View file @
a608b4c6
...
...
@@ -32,12 +32,12 @@ import torch
from
torch
import
nn
from
transformers
import
ApertusConfig
from
vllm.attention.layer
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
XIELU
from
vllm.model_executor.layers.attention.encoder_only_attention
import
(
from
vllm.model_executor.layers.attention
import
(
Attention
,
EncoderOnlyAttention
,
)
from
vllm.model_executor.layers.layernorm
import
RMSNorm
...
...
vllm/model_executor/models/arctic.py
View file @
a608b4c6
...
...
@@ -8,7 +8,6 @@ from itertools import islice
import
torch
from
torch
import
nn
from
vllm.attention.layer
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
(
...
...
@@ -19,6 +18,7 @@ from vllm.distributed import (
)
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.layers.fused_moe
import
fused_experts
,
fused_topk
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
...
...
Prev
1
2
3
4
5
6
…
8
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment