Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3951d3ea
Unverified
Commit
3951d3ea
authored
Apr 22, 2026
by
Martin Hickey
Committed by
GitHub
Apr 21, 2026
Browse files
[MyPy] Enable mypy for `vllm/model_executor/layers/` (#40159)
Signed-off-by:
Martin Hickey
<
martin.hickey@ie.ibm.com
>
parent
6f2c71be
Changes
28
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
189 additions
and
119 deletions
+189
-119
tools/pre_commit/mypy.py
tools/pre_commit/mypy.py
+0
-1
vllm/model_executor/layers/activation.py
vllm/model_executor/layers/activation.py
+15
-12
vllm/model_executor/layers/attention/attention.py
vllm/model_executor/layers/attention/attention.py
+15
-6
vllm/model_executor/layers/attention/chunked_local_attention.py
...odel_executor/layers/attention/chunked_local_attention.py
+1
-1
vllm/model_executor/layers/attention/cross_attention.py
vllm/model_executor/layers/attention/cross_attention.py
+9
-4
vllm/model_executor/layers/attention/encoder_only_attention.py
...model_executor/layers/attention/encoder_only_attention.py
+2
-2
vllm/model_executor/layers/attention/mla_attention.py
vllm/model_executor/layers/attention/mla_attention.py
+17
-10
vllm/model_executor/layers/fused_moe/all2all_utils.py
vllm/model_executor/layers/fused_moe/all2all_utils.py
+7
-4
vllm/model_executor/layers/fused_moe/config.py
vllm/model_executor/layers/fused_moe/config.py
+2
-1
vllm/model_executor/layers/fused_moe/experts/batched_deep_gemm_moe.py
...xecutor/layers/fused_moe/experts/batched_deep_gemm_moe.py
+3
-3
vllm/model_executor/layers/fused_moe/oracle/mxfp4.py
vllm/model_executor/layers/fused_moe/oracle/mxfp4.py
+8
-5
vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_one_sided.py
...fused_moe/prepare_finalize/flashinfer_nvlink_one_sided.py
+18
-6
vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_two_sided.py
...fused_moe/prepare_finalize/flashinfer_nvlink_two_sided.py
+16
-9
vllm/model_executor/layers/fused_moe/prepare_finalize/naive_dp_ep.py
...executor/layers/fused_moe/prepare_finalize/naive_dp_ep.py
+4
-0
vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py
...el_executor/layers/fused_moe/runner/default_moe_runner.py
+3
-1
vllm/model_executor/layers/kda.py
vllm/model_executor/layers/kda.py
+25
-18
vllm/model_executor/layers/layernorm.py
vllm/model_executor/layers/layernorm.py
+1
-1
vllm/model_executor/layers/mamba/abstract.py
vllm/model_executor/layers/mamba/abstract.py
+2
-1
vllm/model_executor/layers/mamba/gdn_linear_attn.py
vllm/model_executor/layers/mamba/gdn_linear_attn.py
+36
-30
vllm/model_executor/layers/mamba/linear_attn.py
vllm/model_executor/layers/mamba/linear_attn.py
+5
-4
No files found.
tools/pre_commit/mypy.py
View file @
3951d3ea
...
...
@@ -29,7 +29,6 @@ SEPARATE_GROUPS = [
"tests"
,
# v0 related
"vllm/lora"
,
"vllm/model_executor/layers"
,
]
# TODO(woosuk): Include the code from Megatron and HuggingFace.
...
...
vllm/model_executor/layers/activation.py
View file @
3951d3ea
...
...
@@ -666,16 +666,7 @@ _ACTIVATION_REGISTRY = LazyDict(
"gelu"
:
lambda
:
GELU
(),
"gelu_fast"
:
lambda
:
FastGELU
(),
"gelu_new"
:
lambda
:
NewGELU
(),
"gelu_pytorch_tanh"
:
lambda
:
(
# TODO:[ROCm] PyTorch native GELU with tanh is unstable with torch.compile
logger
.
warning_once
(
"[ROCm] PyTorch's native GELU with tanh approximation is unstable. "
"Falling back to GELU(approximate='none')."
),
nn
.
GELU
(
approximate
=
"none"
),
)[
1
]
if
current_platform
.
is_rocm
()
else
nn
.
GELU
(
approximate
=
"tanh"
),
"gelu_pytorch_tanh"
:
lambda
:
_get_gelu_pytorch_tanh
(),
"relu"
:
lambda
:
nn
.
ReLU
(),
"relu2"
:
lambda
:
ReLUSquaredActivation
(),
"silu"
:
lambda
:
nn
.
SiLU
(),
...
...
@@ -687,6 +678,18 @@ _ACTIVATION_REGISTRY = LazyDict(
)
def
_get_gelu_pytorch_tanh
()
->
nn
.
Module
:
"""Get PyTorch GELU with tanh approximation, with ROCm fallback."""
if
current_platform
.
is_rocm
():
# TODO:[ROCm] PyTorch native GELU with tanh is unstable with torch.compile
logger
.
warning_once
(
"[ROCm] PyTorch's native GELU with tanh approximation is unstable. "
"Falling back to GELU(approximate='none')."
)
return
nn
.
GELU
(
approximate
=
"none"
)
return
nn
.
GELU
(
approximate
=
"tanh"
)
def
get_act_fn
(
act_fn_name
:
str
)
->
nn
.
Module
:
"""Get an activation function by name."""
act_fn_name
=
act_fn_name
.
lower
()
...
...
@@ -703,12 +706,12 @@ def get_act_fn(act_fn_name: str) -> nn.Module:
return
_ACTIVATION_REGISTRY
[
act_fn_name
]
_ACTIVATION_AND_MUL_REGISTRY
=
LazyDict
(
_ACTIVATION_AND_MUL_REGISTRY
:
LazyDict
[
nn
.
Module
]
=
LazyDict
(
{
"gelu"
:
lambda
:
GeluAndMul
(),
"silu"
:
lambda
:
SiluAndMul
(),
"geglu"
:
lambda
:
GeluAndMul
(),
"swigluoai"
:
lambda
*
args
,
**
kwargs
:
SwigluOAIAndMul
(
*
args
,
**
kwargs
),
"swigluoai"
:
lambda
:
SwigluOAIAndMul
(),
}
)
...
...
vllm/model_executor/layers/attention/attention.py
View file @
3951d3ea
...
...
@@ -33,6 +33,7 @@ from vllm.utils.torch_utils import (
)
from
vllm.v1.attention.backend
import
(
AttentionBackend
,
AttentionMetadata
,
AttentionType
,
)
from
vllm.v1.attention.backends.registry
import
AttentionBackendEnum
...
...
@@ -209,6 +210,7 @@ class Attention(nn.Module, AttentionLayerBase):
`self.kv_cache`.
"""
super
().
__init__
()
sliding_window
:
int
|
None
if
per_layer_sliding_window
is
not
None
:
# per-layer sliding window
sliding_window
=
per_layer_sliding_window
...
...
@@ -335,7 +337,7 @@ class Attention(nn.Module, AttentionLayerBase):
cache_config
.
enable_prefix_caching
=
False
impl_cls
=
self
.
attn_backend
.
get_impl_cls
()
self
.
impl
=
impl_cls
(
self
.
impl
=
impl_cls
(
# type: ignore[assignment] # impl_cls always returns an AttentionImpl subclass
num_heads
,
head_size
,
scale
,
...
...
@@ -576,7 +578,7 @@ class Attention(nn.Module, AttentionLayerBase):
def
get_attn_backend
(
self
)
->
type
[
AttentionBackend
]:
return
self
.
attn_backend
def
get_kv_cache_spec
(
self
,
vllm_config
:
VllmConfig
)
->
KVCacheSpec
:
def
get_kv_cache_spec
(
self
,
vllm_config
:
VllmConfig
)
->
KVCacheSpec
|
None
:
# Block size may get updated after model loading, refresh it
block_size
=
vllm_config
.
cache_config
.
block_size
# Should not be called for enc-dec or encoder-only attention.
...
...
@@ -680,9 +682,16 @@ def get_attention_context(
extracted from the forward context.
"""
forward_context
:
ForwardContext
=
get_forward_context
()
attn_metadata
=
forward_context
.
attn_metadata
if
isinstance
(
attn_metadata
,
dict
):
attn_metadata
=
attn_metadata
[
layer_name
]
attn_metadata_raw
=
forward_context
.
attn_metadata
attn_metadata
:
AttentionMetadata
if
isinstance
(
attn_metadata_raw
,
dict
):
attn_metadata
=
attn_metadata_raw
[
layer_name
]
elif
isinstance
(
attn_metadata_raw
,
list
):
# list[dict[str, AttentionMetadata]]: used in speculative decoding
# where [0] is the base-model (non-speculative) metadata dict.
attn_metadata
=
attn_metadata_raw
[
0
][
layer_name
]
else
:
attn_metadata
=
attn_metadata_raw
attn_layer
:
Attention
|
MLAAttention
=
forward_context
.
no_compile_layers
[
layer_name
]
kv_cache
=
attn_layer
.
kv_cache
slot_mapping
=
forward_context
.
slot_mapping
...
...
@@ -708,7 +717,7 @@ def unified_kv_cache_update(
assert
hasattr
(
attn_layer
.
impl
,
"do_kv_cache_update"
),
(
f
"
{
attn_layer
.
impl
.
__class__
.
__name__
}
does not support kv cache update"
)
attn_layer
.
impl
.
do_kv_cache_update
(
attn_layer
.
impl
.
do_kv_cache_update
(
# type: ignore[attr-defined]
attn_layer
,
key
,
value
,
...
...
vllm/model_executor/layers/attention/chunked_local_attention.py
View file @
3951d3ea
...
...
@@ -29,7 +29,7 @@ from vllm.v1.kv_cache_interface import (
@
functools
.
lru_cache
def
create_chunked_local_attention_backend
(
underlying_attn_backend
:
AttentionBackend
,
underlying_attn_backend
:
type
[
AttentionBackend
]
,
attention_chunk_size
:
int
,
)
->
type
[
AttentionBackend
]:
prefix
=
f
"ChunkedLocalAttention_
{
attention_chunk_size
}
_"
...
...
vllm/model_executor/layers/attention/cross_attention.py
View file @
3951d3ea
...
...
@@ -72,7 +72,7 @@ def _get_cross_slot_mapping(
@
functools
.
lru_cache
def
create_cross_attention_backend
(
underlying_attn_backend
:
AttentionBackend
,
underlying_attn_backend
:
type
[
AttentionBackend
]
,
)
->
type
[
AttentionBackend
]:
prefix
=
"CrossAttention_"
underlying_builder
=
underlying_attn_backend
.
get_builder_cls
()
...
...
@@ -87,6 +87,7 @@ def create_cross_attention_backend(
)
->
AttentionMetadata
:
new_metadata
=
copy
(
common_attn_metadata
)
new_metadata
.
causal
=
False
assert
new_metadata
.
encoder_seq_lens_cpu
is
not
None
max_encoder_len
=
int
(
new_metadata
.
encoder_seq_lens_cpu
.
max
())
new_metadata
.
max_seq_len
=
max_encoder_len
# Any computed tokens indicated decode step>1 (no chunked prefill)
...
...
@@ -118,7 +119,7 @@ def create_cross_attention_backend(
self
.
device
,
)
attn_metadata
=
super
().
build
(
common_prefix_len
,
new_metadata
,
fast_build
)
attn_metadata
.
slot_mapping
=
slot_mapping
attn_metadata
.
slot_mapping
=
slot_mapping
# type: ignore[attr-defined]
return
attn_metadata
# NOTE(Lucas): we need a custom impl so we can use the slot-mapping computed by
...
...
@@ -144,8 +145,12 @@ def create_cross_attention_backend(
and
key
is
not
None
and
value
is
not
None
):
self
.
do_kv_cache_update
(
layer
,
key
,
value
,
kv_cache
,
attn_metadata
.
slot_mapping
self
.
do_kv_cache_update
(
# type: ignore[attr-defined]
layer
,
key
,
value
,
kv_cache
,
attn_metadata
.
slot_mapping
,
# type: ignore[attr-defined]
)
return
super
().
forward
(
...
...
vllm/model_executor/layers/attention/encoder_only_attention.py
View file @
3951d3ea
...
...
@@ -21,7 +21,7 @@ from vllm.v1.kv_cache_interface import KVCacheSpec
@
functools
.
lru_cache
def
create_encoder_only_attention_backend
(
underlying_attn_backend
:
AttentionBackend
,
underlying_attn_backend
:
type
[
AttentionBackend
]
,
)
->
type
[
AttentionBackend
]:
prefix
=
"EncoderOnlyAttention_"
underlying_builder
=
underlying_attn_backend
.
get_builder_cls
()
...
...
@@ -93,6 +93,6 @@ class EncoderOnlyAttention(Attention):
**
kwargs
,
)
def
get_kv_cache_spec
(
self
,
vllm_config
:
VllmConfig
)
->
KVCacheSpec
:
def
get_kv_cache_spec
(
self
,
vllm_config
:
VllmConfig
)
->
KVCacheSpec
|
None
:
# Does not need KV cache
return
None
vllm/model_executor/layers/attention/mla_attention.py
View file @
3951d3ea
...
...
@@ -389,7 +389,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
cache_config
.
enable_prefix_caching
=
False
impl_cls
=
cast
(
type
[
MLAAttentionImpl
],
self
.
attn_backend
.
get_impl_cls
())
self
.
impl
=
impl_cls
(
self
.
impl
=
impl_cls
(
# type: ignore[assignment] # impl_cls always returns an MLAAttentionImpl subclass
num_heads
=
self
.
num_heads
,
head_size
=
self
.
head_size
,
scale
=
self
.
scale
,
...
...
@@ -485,16 +485,23 @@ class MLAAttention(nn.Module, AttentionLayerBase):
if
self
.
use_direct_call
:
forward_context
:
ForwardContext
=
get_forward_context
()
attn_metadata
=
forward_context
.
attn_metadata
if
isinstance
(
attn_metadata
,
dict
):
attn_metadata
=
attn_metadata
[
self
.
layer_name
]
attn_metadata_raw
=
forward_context
.
attn_metadata
attn_metadata
:
MLACommonMetadata
if
isinstance
(
attn_metadata_raw
,
dict
):
attn_metadata
=
attn_metadata_raw
[
self
.
layer_name
]
# type: ignore[assignment]
elif
isinstance
(
attn_metadata_raw
,
list
):
# list[dict[str, AttentionMetadata]]: used in speculative decoding
# where [0] is the base-model (non-speculative) metadata dict.
attn_metadata
=
attn_metadata_raw
[
0
][
self
.
layer_name
]
# type: ignore[assignment]
else
:
attn_metadata
=
attn_metadata_raw
self_kv_cache
=
self
.
kv_cache
slot_mapping
=
forward_context
.
slot_mapping
assert
isinstance
(
slot_mapping
,
dict
),
(
f
"Expected slot_mapping to be a dict, got
{
type
(
slot_mapping
)
}
. "
)
self
.
impl
.
do_kv_cache_update
(
self
.
impl
.
do_kv_cache_update
(
# type: ignore[attr-defined]
kv_c_normed
,
k_pe
,
self_kv_cache
,
...
...
@@ -612,7 +619,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
num_mha_tokens
=
q
.
size
(
0
)
-
num_mqa_tokens
if
num_mha_tokens
>
0
:
self
.
impl
.
forward_mha
(
self
.
impl
.
forward_mha
(
# type: ignore[attr-defined]
q
[
num_mqa_tokens
:],
k_c_normed
[
num_mqa_tokens
:],
k_pe
[
num_mqa_tokens
:],
...
...
@@ -695,7 +702,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
# call decode attn
if
not
is_sparse_impl
:
assert
attn_metadata
.
decode
is
not
None
attn_out
,
lse
=
self
.
impl
.
forward_mqa
(
mqa_q
,
kv_cache
,
attn_metadata
,
self
)
attn_out
,
lse
=
self
.
impl
.
forward_mqa
(
mqa_q
,
kv_cache
,
attn_metadata
,
self
)
# type: ignore[attr-defined]
# correct dcp attn_out with lse.
if
self
.
impl
.
dcp_world_size
>
1
:
...
...
@@ -1053,9 +1060,9 @@ except ImportError:
"AITER_MLA backends use aiter kernels instead."
)
elif
current_platform
.
is_xpu
():
from
vllm._xpu_ops
import
xpu_ops
as
ops
from
vllm._xpu_ops
import
xpu_ops
flash_attn_varlen_func
=
ops
.
flash_attn_varlen_func
# type: ignore[no-redef]
flash_attn_varlen_func
=
xpu_
ops
.
flash_attn_varlen_func
# type: ignore[no-redef
,attr-defined,assignment
]
def
dynamic_per_batched_tensor_quant
(
...
...
@@ -1988,7 +1995,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
assert
isinstance
(
attn_metadata
.
prefill
,
FlashInferPrefillMetadata
)
self
.
_build_fi_prefill_wrappers
(
attn_metadata
.
prefill
)
return
attn_metadata
return
attn_metadata
# type: ignore[return-value]
def
reorg_kvcache
(
...
...
vllm/model_executor/layers/fused_moe/all2all_utils.py
View file @
3951d3ea
...
...
@@ -117,17 +117,20 @@ def maybe_make_prepare_finalize(
"Detected DP deployment with no --enable-expert-parallel. "
"Falling back to AllGather+ReduceScatter dispatch/combine."
)
device_communicator
=
get_ep_group
().
device_communicator
assert
device_communicator
is
not
None
assert
device_communicator
.
all2all_manager
is
not
None
return
make_moe_prepare_and_finalize_naive_dp_ep
(
is_sequence_parallel
=
moe
.
moe_parallel_config
.
is_sequence_parallel
,
num_dispatchers
=
(
get_ep_group
().
device_communicator
.
all2all_manager
.
world_size
),
num_dispatchers
=
(
device_communicator
.
all2all_manager
.
world_size
),
use_monolithic
=
use_monolithic
,
)
else
:
return
make_moe_prepare_and_finalize_no_dp_ep
(
use_monolithic
)
all2all_manager
=
get_ep_group
().
device_communicator
.
all2all_manager
device_communicator
=
get_ep_group
().
device_communicator
assert
device_communicator
is
not
None
all2all_manager
=
device_communicator
.
all2all_manager
assert
all2all_manager
is
not
None
prepare_finalize
:
FusedMoEPrepareAndFinalize
|
None
=
None
...
...
vllm/model_executor/layers/fused_moe/config.py
View file @
3951d3ea
...
...
@@ -7,6 +7,7 @@ from typing import Union
import
torch
from
vllm.config
import
ParallelConfig
,
SchedulerConfig
from
vllm.config.kernel
import
MoEBackend
from
vllm.distributed
import
get_dp_group
,
get_pcp_group
,
get_tensor_model_parallel_rank
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe.activation
import
MoEActivation
...
...
@@ -1192,7 +1193,7 @@ class FusedMoEConfig:
# Defaults to intermediate_size_per_partition if not specified.
intermediate_size_per_partition_unpadded
:
int
|
None
=
None
moe_backend
:
str
=
"auto"
moe_backend
:
MoEBackend
=
"auto"
max_num_tokens
:
int
=
SchedulerConfig
.
DEFAULT_MAX_NUM_BATCHED_TOKENS_FOR_BATCHED_DP
has_bias
:
bool
=
False
is_act_and_mul
:
bool
=
True
...
...
vllm/model_executor/layers/fused_moe/experts/batched_deep_gemm_moe.py
View file @
3951d3ea
...
...
@@ -210,9 +210,9 @@ def persistent_masked_m_silu_mul_quant(
DeepGemmQuantScaleFMT
.
UE8M0
,
]
cuda_arch
=
current_platform
.
get_device_capability
(
device_
id
=
y
.
device
.
index
)
.
to_int
()
device_capability
=
current_platform
.
get_device_capability
(
device_id
=
y
.
device
.
index
)
assert
device_
capability
is
not
None
cuda_arch
=
device_capability
.
to_int
()
if
current_platform
.
is_cuda
()
and
cuda_arch
>=
80
:
torch
.
ops
.
_C
.
persistent_masked_m_silu_mul_quant
(
...
...
vllm/model_executor/layers/fused_moe/oracle/mxfp4.py
View file @
3951d3ea
...
...
@@ -7,6 +7,7 @@ import torch
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm
import
envs
from
vllm.config.kernel
import
MoEBackend
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe
import
(
FusedMoEConfig
,
...
...
@@ -146,7 +147,7 @@ def backend_to_kernel_cls(
raise
ValueError
(
f
"Unknown MXFP4 MoE backend:
{
backend
.
value
}
"
)
def
map_mxfp4_backend
(
runner_backend
:
str
)
->
Mxfp4MoeBackend
:
def
map_mxfp4_backend
(
runner_backend
:
MoEBackend
)
->
Mxfp4MoeBackend
:
"""Map user's moe_backend string to Mxfp4MoeBackend."""
mapping
:
dict
[
str
,
Mxfp4MoeBackend
]
=
{
"flashinfer_trtllm"
:
Mxfp4MoeBackend
.
FLASHINFER_TRTLLM_MXFP4_BF16
,
...
...
@@ -201,10 +202,12 @@ def select_gpt_oss_mxfp4_moe_backend(
Select the primary MXFP4 MoE backend.
Note: Shape-specific fallbacks may still occur at runtime.
"""
triton_kernels_supported
=
has_triton_kernels
()
and
(
9
,
0
,
)
<=
current_platform
.
get_device_capability
()
<
(
11
,
0
)
device_capability
=
current_platform
.
get_device_capability
()
triton_kernels_supported
=
(
has_triton_kernels
()
and
device_capability
is
not
None
and
(
9
,
0
)
<=
device_capability
<
(
11
,
0
)
)
# LoRA: separate experts backend path
if
config
.
is_lora_enabled
:
...
...
vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_one_sided.py
View file @
3951d3ea
...
...
@@ -4,6 +4,9 @@ import torch
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm.distributed
import
get_ep_group
from
vllm.distributed.device_communicators.base_device_communicator
import
(
All2AllManagerBase
,
)
from
vllm.forward_context
import
get_forward_context
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEQuantConfig
from
vllm.model_executor.layers.fused_moe.utils
import
moe_kernel_quantize_input
...
...
@@ -11,12 +14,16 @@ from vllm.utils.flashinfer import nvfp4_block_scale_interleave
def
get_local_sizes
():
return
get_forward_context
().
dp_metadata
.
get_chunk_sizes_across_dp_rank
()
dp_metadata
=
get_forward_context
().
dp_metadata
assert
dp_metadata
is
not
None
return
dp_metadata
.
get_chunk_sizes_across_dp_rank
()
class
FlashInferNVLinkOneSidedPrepareAndFinalize
(
mk
.
FusedMoEPrepareAndFinalizeModular
):
"""FlashInfer implementation using the Moe AlltoAll kernel."""
all2all_manager
:
All2AllManagerBase
def
__init__
(
self
,
max_num_tokens
:
int
,
...
...
@@ -32,8 +39,12 @@ class FlashInferNVLinkOneSidedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeMo
self
.
hidden_size
=
hidden_size
self
.
num_dispatchers_
=
num_dispatchers
self
.
all2all_manager
=
get_ep_group
().
device_communicator
.
all2all_manager
self
.
all2all_manager
.
initialize
(
device_communicator
=
get_ep_group
().
device_communicator
assert
device_communicator
is
not
None
all2all_manager
=
device_communicator
.
all2all_manager
assert
all2all_manager
is
not
None
self
.
all2all_manager
=
all2all_manager
self
.
all2all_manager
.
initialize
(
# type: ignore[attr-defined]
max_num_tokens
=
self
.
max_num_tokens
,
top_k
=
self
.
top_k
,
num_experts
=
self
.
num_experts
,
...
...
@@ -97,7 +108,8 @@ class FlashInferNVLinkOneSidedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeMo
payloads
.
append
(
topk_ids
)
payloads
.
append
(
topk_weights
)
recv_payloads
=
self
.
all2all_manager
.
moe_alltoall
.
dispatch
(
assert
self
.
all2all_manager
.
moe_alltoall
is
not
None
# type: ignore[attr-defined]
recv_payloads
=
self
.
all2all_manager
.
moe_alltoall
.
dispatch
(
# type: ignore[attr-defined]
token_selected_experts
=
topk_ids
,
input_payloads
=
payloads
,
runtime_max_tokens_per_rank
=
self
.
runtime_max_tokens_per_rank
,
...
...
@@ -131,7 +143,7 @@ class FlashInferNVLinkOneSidedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeMo
apply_router_weight_on_input
:
bool
,
weight_and_reduce_impl
:
mk
.
TopKWeightAndReduce
,
)
->
None
:
assert
self
.
all2all_manager
.
moe_alltoall
is
not
None
assert
self
.
all2all_manager
.
moe_alltoall
is
not
None
# type: ignore[attr-defined]
ep_size
=
self
.
all2all_manager
.
world_size
hidden_size
=
fused_expert_output
.
shape
[
-
1
]
...
...
@@ -139,7 +151,7 @@ class FlashInferNVLinkOneSidedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeMo
ep_size
,
self
.
runtime_max_tokens_per_rank
,
hidden_size
)
combined_output
=
self
.
all2all_manager
.
moe_alltoall
.
combine
(
combined_output
=
self
.
all2all_manager
.
moe_alltoall
.
combine
(
# type: ignore[attr-defined]
payload
=
fused_expert_output
,
runtime_max_tokens_per_rank
=
self
.
runtime_max_tokens_per_rank
,
)
...
...
vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_two_sided.py
View file @
3951d3ea
...
...
@@ -15,19 +15,26 @@ from vllm.utils.flashinfer import nvfp4_block_scale_interleave
def
get_local_sizes
():
return
get_forward_context
().
dp_metadata
.
get_chunk_sizes_across_dp_rank
()
dp_metadata
=
get_forward_context
().
dp_metadata
assert
dp_metadata
is
not
None
return
dp_metadata
.
get_chunk_sizes_across_dp_rank
()
class
FlashInferNVLinkTwoSidedPrepareAndFinalize
(
mk
.
FusedMoEPrepareAndFinalizeModular
):
"""Base class for FlashInfer MoE prepare and finalize operations."""
all2all_manager
:
All2AllManagerBase
def
__init__
(
self
,
num_dispatchers
:
int
=
1
,
):
super
().
__init__
()
self
.
num_dispatchers_
=
num_dispatchers
self
.
all2all_manager
=
get_ep_group
().
device_communicator
.
all2all_manager
device_communicator
=
get_ep_group
().
device_communicator
assert
device_communicator
is
not
None
assert
device_communicator
.
all2all_manager
is
not
None
self
.
all2all_manager
=
device_communicator
.
all2all_manager
@
property
def
activation_format
(
self
)
->
mk
.
FusedMoEActivationFormat
:
...
...
@@ -129,7 +136,7 @@ def flashinfer_alltoall_dispatch(
):
from
flashinfer.comm.trtllm_alltoall
import
MnnvlMoe
assert
all2all_manager
.
ensure_alltoall_workspace_initialized
(),
(
assert
all2all_manager
.
ensure_alltoall_workspace_initialized
(),
(
# type: ignore[attr-defined]
"FlashInfer AllToAll workspace not available"
)
...
...
@@ -144,7 +151,7 @@ def flashinfer_alltoall_dispatch(
topk_ids
,
topk_weights
,
None
,
all2all_manager
.
prepare_workspace_tensor
,
all2all_manager
.
prepare_workspace_tensor
,
# type: ignore[attr-defined]
max_num_token
,
ep_rank
,
ep_size
,
...
...
@@ -172,7 +179,7 @@ def flashinfer_alltoall_dispatch(
x
=
MnnvlMoe
.
mnnvl_moe_alltoallv
(
x
,
alltoall_info
,
all2all_manager
.
workspace_tensor
,
all2all_manager
.
workspace_tensor
,
# type: ignore[attr-defined]
ep_rank
,
ep_size
,
)
...
...
@@ -180,7 +187,7 @@ def flashinfer_alltoall_dispatch(
x_sf
=
MnnvlMoe
.
mnnvl_moe_alltoallv
(
x_sf
,
alltoall_info
,
all2all_manager
.
workspace_tensor
,
all2all_manager
.
workspace_tensor
,
# type: ignore[attr-defined]
ep_rank
,
ep_size
,
)
...
...
@@ -196,7 +203,7 @@ def flashinfer_alltoall_dispatch(
x
=
MnnvlMoe
.
mnnvl_moe_alltoallv
(
x
,
alltoall_info
,
all2all_manager
.
workspace_tensor
,
all2all_manager
.
workspace_tensor
,
# type: ignore[attr-defined]
ep_rank
,
ep_size
,
)
...
...
@@ -212,13 +219,13 @@ def flashinfer_alltoall_combine(
):
from
flashinfer.comm.trtllm_alltoall
import
MnnvlMoe
assert
all2all_manager
.
ensure_alltoall_workspace_initialized
(),
(
assert
all2all_manager
.
ensure_alltoall_workspace_initialized
(),
(
# type: ignore[attr-defined]
"FlashInfer AllToAll workspace not available"
)
return
MnnvlMoe
.
mnnvl_moe_alltoallv_combine
(
output
,
alltoall_info
,
all2all_manager
.
workspace_tensor
,
all2all_manager
.
workspace_tensor
,
# type: ignore[attr-defined]
ep_rank
=
all2all_manager
.
rank
,
ep_size
=
all2all_manager
.
world_size
,
top_k
=
top_k
,
...
...
vllm/model_executor/layers/fused_moe/prepare_finalize/naive_dp_ep.py
View file @
3951d3ea
...
...
@@ -132,9 +132,11 @@ class MoEPrepareAndFinalizeNaiveDPEPModular(mk.FusedMoEPrepareAndFinalizeModular
)
if
scales
is
None
:
assert
len
(
res
)
==
3
a1q
,
topk_weights
,
topk_ids
=
res
a1q_scale
=
None
else
:
assert
len
(
res
)
==
4
a1q
,
topk_weights
,
topk_ids
,
scales
=
res
a1q_scale
=
_unwrap_scale_and_prepare_for_moe
(
scales
,
quant_config
)
...
...
@@ -217,9 +219,11 @@ class MoEPrepareAndFinalizeNaiveDPEPMonolithic(mk.FusedMoEPrepareAndFinalizeMono
)
if
scales
is
None
:
assert
len
(
res
)
==
2
a1q
,
router_logits
=
res
a1q_scale
=
None
else
:
assert
len
(
res
)
==
3
a1q
,
router_logits
,
scales
=
res
a1q_scale
=
_unwrap_scale_and_prepare_for_moe
(
scales
,
quant_config
)
...
...
vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py
View file @
3951d3ea
...
...
@@ -54,11 +54,13 @@ class DefaultMoERunner(MoERunnerBase):
# NOTE: this will be removed once all kernels are migrated into the
# MoEKernel framework.
if
self
.
do_naive_dispatch_combine
:
hidden_states
,
router_logit
s
=
get_ep_group
().
dispatch_router_logits
(
re
s
=
get_ep_group
().
dispatch_router_logits
(
hidden_states
,
router_logits
,
self
.
moe_config
.
is_sequence_parallel
,
)
assert
len
(
res
)
==
2
hidden_states
,
router_logits
=
res
# NOTE: Similar with DP, PCP also needs dispatch and combine. For
# simplicity, AgRsAll2All was added separately for PCP here. Maybe
...
...
vllm/model_executor/layers/kda.py
View file @
3951d3ea
...
...
@@ -16,7 +16,6 @@ from vllm.logger import init_logger
from
vllm.model_executor.model_loader.weight_utils
import
sharded_weight_loader
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.utils.torch_utils
import
direct_register_custom_op
from
vllm.v1.attention.backend
import
AttentionMetadata
from
vllm.v1.attention.backends.gdn_attn
import
GDNAttentionMetadata
from
.fla.ops.kda
import
(
...
...
@@ -123,7 +122,7 @@ class KimiDeltaAttention(nn.Module, MambaBase):
self
.
cache_config
=
cache_config
if
model_config
is
None
:
raise
ValueError
(
"model_config must be provided"
)
kda_config
=
model_config
.
linear_attn_config
kda_config
=
model_config
.
linear_attn_config
# type: ignore[attr-defined]
self
.
head_dim
=
kda_config
[
"head_dim"
]
self
.
num_heads
=
kda_config
[
"num_heads"
]
self
.
layer_idx
=
layer_idx
...
...
@@ -297,19 +296,21 @@ class KimiDeltaAttention(nn.Module, MambaBase):
core_attn_out
:
torch
.
Tensor
,
)
->
None
:
forward_context
=
get_forward_context
()
attn_metadata
:
AttentionMetadata
=
forward_context
.
attn_metadata
attn_metadata
_raw
=
forward_context
.
attn_metadata
if
attn_metadata
is
None
:
if
attn_metadata
_raw
is
None
:
# # V1 profile run
return
assert
isinstance
(
attn_metadata
,
dict
)
attn_metadata
=
attn_metadata
[
self
.
prefix
]
assert
isinstance
(
attn_metadata
,
GDNAttentionMetadata
)
has_initial_state
=
attn_metadata
.
has_initial_state
non_spec_query_start_loc
=
attn_metadata
.
non_spec_query_start_loc
non_spec_state_indices_tensor
=
attn_metadata
.
non_spec_state_indices_tensor
# noqa: E501
num_actual_tokens
=
attn_metadata
.
num_actual_tokens
assert
isinstance
(
attn_metadata_raw
,
dict
)
attn_metadata_narrowed
=
attn_metadata_raw
[
self
.
prefix
]
assert
isinstance
(
attn_metadata_narrowed
,
GDNAttentionMetadata
)
has_initial_state
=
attn_metadata_narrowed
.
has_initial_state
non_spec_query_start_loc
=
attn_metadata_narrowed
.
non_spec_query_start_loc
non_spec_state_indices_tensor
=
(
attn_metadata_narrowed
.
non_spec_state_indices_tensor
)
# noqa: E501
num_actual_tokens
=
attn_metadata_narrowed
.
num_actual_tokens
constant_caches
=
self
.
kv_cache
q_proj_states
=
q_proj_states
[:
num_actual_tokens
]
...
...
@@ -335,7 +336,7 @@ class KimiDeltaAttention(nn.Module, MambaBase):
v_conv_weights
=
self
.
v_conv1d
.
weight
.
view
(
self
.
v_conv1d
.
weight
.
size
(
0
),
self
.
v_conv1d
.
weight
.
size
(
2
)
)
if
attn_metadata
.
num_prefills
>
0
:
if
attn_metadata
_narrowed
.
num_prefills
>
0
:
q_proj_states
=
q_proj_states
.
transpose
(
0
,
1
)
k_proj_states
=
k_proj_states
.
transpose
(
0
,
1
)
v_proj_states
=
v_proj_states
.
transpose
(
0
,
1
)
...
...
@@ -348,7 +349,7 @@ class KimiDeltaAttention(nn.Module, MambaBase):
has_initial_state
=
has_initial_state
,
cache_indices
=
non_spec_state_indices_tensor
,
query_start_loc
=
non_spec_query_start_loc
,
metadata
=
attn_metadata
,
metadata
=
attn_metadata
_narrowed
,
).
transpose
(
0
,
1
)
k
=
causal_conv1d_fn
(
k_proj_states
,
...
...
@@ -359,7 +360,7 @@ class KimiDeltaAttention(nn.Module, MambaBase):
has_initial_state
=
has_initial_state
,
cache_indices
=
non_spec_state_indices_tensor
,
query_start_loc
=
non_spec_query_start_loc
,
metadata
=
attn_metadata
,
metadata
=
attn_metadata
_narrowed
,
).
transpose
(
0
,
1
)
v
=
causal_conv1d_fn
(
v_proj_states
,
...
...
@@ -370,11 +371,12 @@ class KimiDeltaAttention(nn.Module, MambaBase):
has_initial_state
=
has_initial_state
,
cache_indices
=
non_spec_state_indices_tensor
,
query_start_loc
=
non_spec_query_start_loc
,
metadata
=
attn_metadata
,
metadata
=
attn_metadata
_narrowed
,
).
transpose
(
0
,
1
)
else
:
assert
non_spec_state_indices_tensor
is
not
None
decode_conv_indices
=
non_spec_state_indices_tensor
[
:
attn_metadata
.
num_actual_tokens
:
attn_metadata
_narrowed
.
num_actual_tokens
]
q
=
causal_conv1d_update
(
q_proj_states
,
...
...
@@ -408,7 +410,9 @@ class KimiDeltaAttention(nn.Module, MambaBase):
lambda
x
:
rearrange
(
x
,
"n (h d) -> 1 n h d"
,
d
=
self
.
head_dim
),
(
q
,
k
,
v
)
)
if
attn_metadata
.
num_prefills
>
0
:
if
attn_metadata_narrowed
.
num_prefills
>
0
:
assert
non_spec_state_indices_tensor
is
not
None
assert
has_initial_state
is
not
None
zero_idx
=
non_spec_state_indices_tensor
[
~
has_initial_state
]
recurrent_state
[
zero_idx
]
=
0
initial_state
=
recurrent_state
[
non_spec_state_indices_tensor
].
contiguous
()
...
...
@@ -429,6 +433,7 @@ class KimiDeltaAttention(nn.Module, MambaBase):
# Init cache
recurrent_state
[
non_spec_state_indices_tensor
]
=
last_recurrent_state
else
:
assert
non_spec_query_start_loc
is
not
None
(
core_attn_out_non_spec
,
last_recurrent_state
,
...
...
@@ -440,7 +445,9 @@ class KimiDeltaAttention(nn.Module, MambaBase):
beta
=
beta
,
initial_state
=
recurrent_state
,
use_qk_l2norm_in_kernel
=
True
,
cu_seqlens
=
non_spec_query_start_loc
[:
attn_metadata
.
num_decodes
+
1
],
cu_seqlens
=
non_spec_query_start_loc
[
:
attn_metadata_narrowed
.
num_decodes
+
1
],
ssm_state_indices
=
non_spec_state_indices_tensor
,
)
core_attn_out
[
0
,
:
num_actual_tokens
]
=
core_attn_out_non_spec
[
...
...
vllm/model_executor/layers/layernorm.py
View file @
3951d3ea
...
...
@@ -76,7 +76,7 @@ def poly_norm(
from
vllm
import
_custom_ops
as
ops
out
=
torch
.
empty_like
(
x
)
ops
.
poly_norm
(
ops
.
poly_norm
(
# type: ignore[attr-defined]
out
,
x
,
weight
,
...
...
vllm/model_executor/layers/mamba/abstract.py
View file @
3951d3ea
...
...
@@ -42,9 +42,10 @@ class MambaBase(AttentionLayerBase):
def
get_kv_cache_spec
(
self
,
vllm_config
:
VllmConfig
)
->
KVCacheSpec
|
None
:
mamba_block_size
=
vllm_config
.
cache_config
.
mamba_block_size
assert
mamba_block_size
is
not
None
page_size_padded
=
vllm_config
.
cache_config
.
mamba_page_size_padded
return
MambaSpec
(
shapes
=
self
.
get_state_shape
(),
shapes
=
tuple
(
self
.
get_state_shape
()
)
,
dtypes
=
self
.
get_state_dtype
(),
block_size
=
mamba_block_size
,
page_size_padded
=
page_size_padded
,
...
...
vllm/model_executor/layers/mamba/gdn_linear_attn.py
View file @
3951d3ea
...
...
@@ -62,7 +62,6 @@ from vllm.utils.torch_utils import (
_resolve_layer_name
,
direct_register_custom_op
,
)
from
vllm.v1.attention.backend
import
AttentionMetadata
from
vllm.v1.attention.backends.gdn_attn
import
GDNAttentionMetadata
logger
=
init_logger
(
__name__
)
...
...
@@ -121,9 +120,9 @@ def fi_chunk_gated_delta_rule(
class
ChunkGatedDeltaRule
(
CustomOp
):
def
__init__
(
self
)
->
None
:
super
().
__init__
()
backend_cf
g
=
get_current_vllm_config
().
additional_config
.
get
(
"gdn_prefill_backend"
,
"auto"
)
additional_confi
g
=
get_current_vllm_config
().
additional_config
assert
isinstance
(
additional_config
,
dict
)
backend_cfg
=
additional_config
.
get
(
"gdn_prefill_backend"
,
"auto"
)
backend
=
str
(
backend_cfg
).
strip
().
lower
()
supports_flashinfer
=
(
...
...
@@ -621,18 +620,19 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
# Part 2: Core Attention
# ============================================================
forward_context
=
get_forward_context
()
attn_metadata
:
AttentionMetadata
=
forward_context
.
attn_metadata
attn_metadata
_raw
=
forward_context
.
attn_metadata
core_attn_out
=
torch
.
zeros
(
(
num_tokens
,
self
.
num_v_heads
//
self
.
tp_size
,
self
.
head_v_dim
),
dtype
=
hidden_states
.
dtype
,
device
=
hidden_states
.
device
,
)
z
=
torch
.
empty_like
(
core_attn_out
)
if
attn_metadata
is
not
None
:
attn_metadata
=
attn_metadata
[
self
.
prefix
]
if
attn_metadata_raw
is
not
None
:
assert
isinstance
(
attn_metadata_raw
,
dict
)
attn_metadata
=
attn_metadata_raw
[
self
.
prefix
]
# TODO: xpu does not support this param yet
spec_sequence_masks
=
attn_metadata
.
spec_sequence_masks
spec_sequence_masks
=
attn_metadata
.
spec_sequence_masks
# type: ignore[attr-defined]
assert
spec_sequence_masks
is
None
conv_weights
=
self
.
conv1d
.
weight
.
view
(
...
...
@@ -658,12 +658,12 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
activation
=
self
.
activation
,
A_log
=
self
.
A_log
,
dt_bias
=
self
.
dt_bias
,
num_prefills
=
attn_metadata
.
num_prefills
,
num_decodes
=
attn_metadata
.
num_decodes
,
has_initial_state
=
attn_metadata
.
has_initial_state
,
non_spec_query_start_loc
=
attn_metadata
.
non_spec_query_start_loc
,
non_spec_state_indices_tensor
=
attn_metadata
.
non_spec_state_indices_tensor
,
num_actual_tokens
=
attn_metadata
.
num_actual_tokens
,
num_prefills
=
attn_metadata
.
num_prefills
,
# type: ignore[attr-defined]
num_decodes
=
attn_metadata
.
num_decodes
,
# type: ignore[attr-defined]
has_initial_state
=
attn_metadata
.
has_initial_state
,
# type: ignore[attr-defined]
non_spec_query_start_loc
=
attn_metadata
.
non_spec_query_start_loc
,
# type: ignore[attr-defined]
non_spec_state_indices_tensor
=
attn_metadata
.
non_spec_state_indices_tensor
,
# type: ignore[attr-defined]
num_actual_tokens
=
attn_metadata
.
num_actual_tokens
,
# type: ignore[attr-defined]
tp_size
=
self
.
tp_size
,
reorder_input
=
not
self
.
gqa_interleaved_layout
,
)
...
...
@@ -792,16 +792,16 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
core_attn_out
:
torch
.
Tensor
,
):
forward_context
=
get_forward_context
()
attn_metadata
:
AttentionMetadata
=
forward_context
.
attn_metadata
attn_metadata
_raw
=
forward_context
.
attn_metadata
if
attn_metadata
is
None
:
if
attn_metadata
_raw
is
None
:
# V1 profile run — warm up prefill kernels so that
# autotuning completes before KV cache allocation.
self
.
_warmup_prefill_kernels
(
mixed_qkv
)
return
assert
isinstance
(
attn_metadata
,
dict
)
attn_metadata
=
attn_metadata
[
self
.
prefix
]
assert
isinstance
(
attn_metadata
_raw
,
dict
)
attn_metadata
=
attn_metadata
_raw
[
self
.
prefix
]
# type: ignore[index]
assert
isinstance
(
attn_metadata
,
GDNAttentionMetadata
)
if
(
...
...
@@ -860,14 +860,16 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
# 1.1: Process the multi-query part
if
spec_sequence_masks
is
not
None
:
# spec_state_indices_tensor is always set when spec_sequence_masks is set
assert
spec_state_indices_tensor
is
not
None
mixed_qkv_spec
=
causal_conv1d_update
(
mixed_qkv_spec
,
conv_state
,
conv_weights
,
self
.
conv1d
.
bias
,
self
.
activation
,
conv_state_indices
=
spec_state_indices_tensor
[:,
0
][
:
attn_metadata
.
num_spec_decodes
conv_state_indices
=
spec_state_indices_tensor
[:,
0
][
# type: ignore[index]
:
attn_metadata
.
num_spec_decodes
# type: ignore[attr-defined]
],
num_accepted_tokens
=
num_accepted_tokens
,
query_start_loc
=
spec_query_start_loc
,
...
...
@@ -900,8 +902,8 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
conv_weights
,
self
.
conv1d
.
bias
,
self
.
activation
,
conv_state_indices
=
non_spec_state_indices_tensor
[
:
attn_metadata
.
num_actual_tokens
conv_state_indices
=
non_spec_state_indices_tensor
[
# type: ignore[index]
:
attn_metadata
.
num_actual_tokens
# type: ignore[attr-defined]
],
validate_data
=
True
,
)
...
...
@@ -965,8 +967,9 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
v
=
value_spec
,
initial_state
=
ssm_state
,
inplace_final_state
=
True
,
cu_seqlens
=
spec_query_start_loc
[
:
attn_metadata
.
num_spec_decodes
+
1
cu_seqlens
=
spec_query_start_loc
[
# type: ignore[index]
:
attn_metadata
.
num_spec_decodes
+
1
# type: ignore[attr-defined]
],
ssm_state_indices
=
spec_state_indices_tensor
,
num_accepted_tokens
=
num_accepted_tokens
,
...
...
@@ -978,8 +981,10 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
# 2.2: Process the remaining part
if
attn_metadata
.
num_prefills
>
0
:
initial_state
=
ssm_state
[
non_spec_state_indices_tensor
].
contiguous
()
initial_state
[
~
has_initial_state
,
...]
=
0
assert
non_spec_state_indices_tensor
is
not
None
initial_state
=
ssm_state
[
non_spec_state_indices_tensor
].
contiguous
()
# type: ignore[index]
assert
has_initial_state
is
not
None
initial_state
[
~
has_initial_state
,
...]
=
0
# type: ignore[operator]
(
core_attn_out_non_spec
,
last_recurrent_state
,
...
...
@@ -1012,8 +1017,9 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
v
=
value_non_spec
,
initial_state
=
ssm_state
,
inplace_final_state
=
True
,
cu_seqlens
=
non_spec_query_start_loc
[
:
attn_metadata
.
num_decodes
+
1
cu_seqlens
=
non_spec_query_start_loc
[
# type: ignore[index]
:
attn_metadata
.
num_decodes
+
1
# type: ignore[attr-defined]
],
ssm_state_indices
=
non_spec_state_indices_tensor
,
use_qk_l2norm_in_kernel
=
True
,
...
...
@@ -1073,7 +1079,7 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
conv_weights
,
self
.
conv1d
.
bias
,
self
.
activation
,
conv_state_indices
=
non_spec_state_indices_tensor
[:
num_actual_tokens
],
conv_state_indices
=
non_spec_state_indices_tensor
[:
num_actual_tokens
],
# type: ignore[index]
validate_data
=
False
,
)
out_buf
=
core_attn_out
[:
num_actual_tokens
].
unsqueeze
(
1
)
...
...
@@ -1086,7 +1092,7 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
scale
=
self
.
head_k_dim
**-
0.5
,
initial_state
=
ssm_state
,
out
=
out_buf
,
ssm_state_indices
=
non_spec_state_indices_tensor
[:
num_actual_tokens
],
ssm_state_indices
=
non_spec_state_indices_tensor
[:
num_actual_tokens
],
# type: ignore[index]
use_qk_l2norm_in_kernel
=
True
,
)
return
...
...
vllm/model_executor/layers/mamba/linear_attn.py
View file @
3951d3ea
...
...
@@ -396,10 +396,11 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
self
,
hidden_states
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
)
->
None
:
forward_context
=
get_forward_context
()
attn_metadata
:
AttentionMetadata
=
forward_context
.
attn_metadata
if
attn_metadata
is
not
None
:
assert
isinstance
(
attn_metadata
,
dict
)
attn_metadata
=
attn_metadata
[
self
.
prefix
]
attn_metadata_raw
=
forward_context
.
attn_metadata
attn_metadata
:
AttentionMetadata
|
None
=
None
if
attn_metadata_raw
is
not
None
:
assert
isinstance
(
attn_metadata_raw
,
dict
)
attn_metadata
=
attn_metadata_raw
[
self
.
prefix
]
assert
isinstance
(
attn_metadata
,
LinearAttentionMetadata
)
num_actual_tokens
=
(
attn_metadata
.
num_prefill_tokens
+
attn_metadata
.
num_decode_tokens
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment