Unverified commit 3951d3ea, authored by Martin Hickey, committed by GitHub
Browse files

[MyPy] Enable mypy for `vllm/model_executor/layers/` (#40159)


Signed-off-by: Martin Hickey <martin.hickey@ie.ibm.com>
parent 6f2c71be
...@@ -29,7 +29,6 @@ SEPARATE_GROUPS = [ ...@@ -29,7 +29,6 @@ SEPARATE_GROUPS = [
"tests", "tests",
# v0 related # v0 related
"vllm/lora", "vllm/lora",
"vllm/model_executor/layers",
] ]
# TODO(woosuk): Include the code from Megatron and HuggingFace. # TODO(woosuk): Include the code from Megatron and HuggingFace.
......
...@@ -666,16 +666,7 @@ _ACTIVATION_REGISTRY = LazyDict( ...@@ -666,16 +666,7 @@ _ACTIVATION_REGISTRY = LazyDict(
"gelu": lambda: GELU(), "gelu": lambda: GELU(),
"gelu_fast": lambda: FastGELU(), "gelu_fast": lambda: FastGELU(),
"gelu_new": lambda: NewGELU(), "gelu_new": lambda: NewGELU(),
"gelu_pytorch_tanh": lambda: ( "gelu_pytorch_tanh": lambda: _get_gelu_pytorch_tanh(),
# TODO:[ROCm] PyTorch native GELU with tanh is unstable with torch.compile
logger.warning_once(
"[ROCm] PyTorch's native GELU with tanh approximation is unstable. "
"Falling back to GELU(approximate='none')."
),
nn.GELU(approximate="none"),
)[1]
if current_platform.is_rocm()
else nn.GELU(approximate="tanh"),
"relu": lambda: nn.ReLU(), "relu": lambda: nn.ReLU(),
"relu2": lambda: ReLUSquaredActivation(), "relu2": lambda: ReLUSquaredActivation(),
"silu": lambda: nn.SiLU(), "silu": lambda: nn.SiLU(),
...@@ -687,6 +678,18 @@ _ACTIVATION_REGISTRY = LazyDict( ...@@ -687,6 +678,18 @@ _ACTIVATION_REGISTRY = LazyDict(
) )
def _get_gelu_pytorch_tanh() -> nn.Module:
    """Build the module used for the ``gelu_pytorch_tanh`` activation.

    Returns:
        ``nn.GELU(approximate="tanh")`` on every platform except ROCm.
        On ROCm a one-time warning is logged and the exact
        ``nn.GELU(approximate="none")`` is returned instead.
    """
    # Common case first: non-ROCm platforms use the tanh approximation.
    if not current_platform.is_rocm():
        return nn.GELU(approximate="tanh")

    # TODO:[ROCm] PyTorch native GELU with tanh is unstable with torch.compile
    logger.warning_once(
        "[ROCm] PyTorch's native GELU with tanh approximation is unstable. "
        "Falling back to GELU(approximate='none')."
    )
    return nn.GELU(approximate="none")
def get_act_fn(act_fn_name: str) -> nn.Module: def get_act_fn(act_fn_name: str) -> nn.Module:
"""Get an activation function by name.""" """Get an activation function by name."""
act_fn_name = act_fn_name.lower() act_fn_name = act_fn_name.lower()
...@@ -703,12 +706,12 @@ def get_act_fn(act_fn_name: str) -> nn.Module: ...@@ -703,12 +706,12 @@ def get_act_fn(act_fn_name: str) -> nn.Module:
return _ACTIVATION_REGISTRY[act_fn_name] return _ACTIVATION_REGISTRY[act_fn_name]
_ACTIVATION_AND_MUL_REGISTRY = LazyDict( _ACTIVATION_AND_MUL_REGISTRY: LazyDict[nn.Module] = LazyDict(
{ {
"gelu": lambda: GeluAndMul(), "gelu": lambda: GeluAndMul(),
"silu": lambda: SiluAndMul(), "silu": lambda: SiluAndMul(),
"geglu": lambda: GeluAndMul(), "geglu": lambda: GeluAndMul(),
"swigluoai": lambda *args, **kwargs: SwigluOAIAndMul(*args, **kwargs), "swigluoai": lambda: SwigluOAIAndMul(),
} }
) )
......
...@@ -33,6 +33,7 @@ from vllm.utils.torch_utils import ( ...@@ -33,6 +33,7 @@ from vllm.utils.torch_utils import (
) )
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
AttentionMetadata,
AttentionType, AttentionType,
) )
from vllm.v1.attention.backends.registry import AttentionBackendEnum from vllm.v1.attention.backends.registry import AttentionBackendEnum
...@@ -209,6 +210,7 @@ class Attention(nn.Module, AttentionLayerBase): ...@@ -209,6 +210,7 @@ class Attention(nn.Module, AttentionLayerBase):
`self.kv_cache`. `self.kv_cache`.
""" """
super().__init__() super().__init__()
sliding_window: int | None
if per_layer_sliding_window is not None: if per_layer_sliding_window is not None:
# per-layer sliding window # per-layer sliding window
sliding_window = per_layer_sliding_window sliding_window = per_layer_sliding_window
...@@ -335,7 +337,7 @@ class Attention(nn.Module, AttentionLayerBase): ...@@ -335,7 +337,7 @@ class Attention(nn.Module, AttentionLayerBase):
cache_config.enable_prefix_caching = False cache_config.enable_prefix_caching = False
impl_cls = self.attn_backend.get_impl_cls() impl_cls = self.attn_backend.get_impl_cls()
self.impl = impl_cls( self.impl = impl_cls( # type: ignore[assignment] # impl_cls always returns an AttentionImpl subclass
num_heads, num_heads,
head_size, head_size,
scale, scale,
...@@ -576,7 +578,7 @@ class Attention(nn.Module, AttentionLayerBase): ...@@ -576,7 +578,7 @@ class Attention(nn.Module, AttentionLayerBase):
def get_attn_backend(self) -> type[AttentionBackend]: def get_attn_backend(self) -> type[AttentionBackend]:
return self.attn_backend return self.attn_backend
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec: def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
# Block size may get updated after model loading, refresh it # Block size may get updated after model loading, refresh it
block_size = vllm_config.cache_config.block_size block_size = vllm_config.cache_config.block_size
# Should not be called for enc-dec or encoder-only attention. # Should not be called for enc-dec or encoder-only attention.
...@@ -680,9 +682,16 @@ def get_attention_context( ...@@ -680,9 +682,16 @@ def get_attention_context(
extracted from the forward context. extracted from the forward context.
""" """
forward_context: ForwardContext = get_forward_context() forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata attn_metadata_raw = forward_context.attn_metadata
if isinstance(attn_metadata, dict): attn_metadata: AttentionMetadata
attn_metadata = attn_metadata[layer_name] if isinstance(attn_metadata_raw, dict):
attn_metadata = attn_metadata_raw[layer_name]
elif isinstance(attn_metadata_raw, list):
# list[dict[str, AttentionMetadata]]: used in speculative decoding
# where [0] is the base-model (non-speculative) metadata dict.
attn_metadata = attn_metadata_raw[0][layer_name]
else:
attn_metadata = attn_metadata_raw
attn_layer: Attention | MLAAttention = forward_context.no_compile_layers[layer_name] attn_layer: Attention | MLAAttention = forward_context.no_compile_layers[layer_name]
kv_cache = attn_layer.kv_cache kv_cache = attn_layer.kv_cache
slot_mapping = forward_context.slot_mapping slot_mapping = forward_context.slot_mapping
...@@ -708,7 +717,7 @@ def unified_kv_cache_update( ...@@ -708,7 +717,7 @@ def unified_kv_cache_update(
assert hasattr(attn_layer.impl, "do_kv_cache_update"), ( assert hasattr(attn_layer.impl, "do_kv_cache_update"), (
f"{attn_layer.impl.__class__.__name__} does not support kv cache update" f"{attn_layer.impl.__class__.__name__} does not support kv cache update"
) )
attn_layer.impl.do_kv_cache_update( attn_layer.impl.do_kv_cache_update( # type: ignore[attr-defined]
attn_layer, attn_layer,
key, key,
value, value,
......
...@@ -29,7 +29,7 @@ from vllm.v1.kv_cache_interface import ( ...@@ -29,7 +29,7 @@ from vllm.v1.kv_cache_interface import (
@functools.lru_cache @functools.lru_cache
def create_chunked_local_attention_backend( def create_chunked_local_attention_backend(
underlying_attn_backend: AttentionBackend, underlying_attn_backend: type[AttentionBackend],
attention_chunk_size: int, attention_chunk_size: int,
) -> type[AttentionBackend]: ) -> type[AttentionBackend]:
prefix = f"ChunkedLocalAttention_{attention_chunk_size}_" prefix = f"ChunkedLocalAttention_{attention_chunk_size}_"
......
...@@ -72,7 +72,7 @@ def _get_cross_slot_mapping( ...@@ -72,7 +72,7 @@ def _get_cross_slot_mapping(
@functools.lru_cache @functools.lru_cache
def create_cross_attention_backend( def create_cross_attention_backend(
underlying_attn_backend: AttentionBackend, underlying_attn_backend: type[AttentionBackend],
) -> type[AttentionBackend]: ) -> type[AttentionBackend]:
prefix = "CrossAttention_" prefix = "CrossAttention_"
underlying_builder = underlying_attn_backend.get_builder_cls() underlying_builder = underlying_attn_backend.get_builder_cls()
...@@ -87,6 +87,7 @@ def create_cross_attention_backend( ...@@ -87,6 +87,7 @@ def create_cross_attention_backend(
) -> AttentionMetadata: ) -> AttentionMetadata:
new_metadata = copy(common_attn_metadata) new_metadata = copy(common_attn_metadata)
new_metadata.causal = False new_metadata.causal = False
assert new_metadata.encoder_seq_lens_cpu is not None
max_encoder_len = int(new_metadata.encoder_seq_lens_cpu.max()) max_encoder_len = int(new_metadata.encoder_seq_lens_cpu.max())
new_metadata.max_seq_len = max_encoder_len new_metadata.max_seq_len = max_encoder_len
# Any computed tokens indicated decode step>1 (no chunked prefill) # Any computed tokens indicated decode step>1 (no chunked prefill)
...@@ -118,7 +119,7 @@ def create_cross_attention_backend( ...@@ -118,7 +119,7 @@ def create_cross_attention_backend(
self.device, self.device,
) )
attn_metadata = super().build(common_prefix_len, new_metadata, fast_build) attn_metadata = super().build(common_prefix_len, new_metadata, fast_build)
attn_metadata.slot_mapping = slot_mapping attn_metadata.slot_mapping = slot_mapping # type: ignore[attr-defined]
return attn_metadata return attn_metadata
# NOTE(Lucas): we need a custom impl so we can use the slot-mapping computed by # NOTE(Lucas): we need a custom impl so we can use the slot-mapping computed by
...@@ -144,8 +145,12 @@ def create_cross_attention_backend( ...@@ -144,8 +145,12 @@ def create_cross_attention_backend(
and key is not None and key is not None
and value is not None and value is not None
): ):
self.do_kv_cache_update( self.do_kv_cache_update( # type: ignore[attr-defined]
layer, key, value, kv_cache, attn_metadata.slot_mapping layer,
key,
value,
kv_cache,
attn_metadata.slot_mapping, # type: ignore[attr-defined]
) )
return super().forward( return super().forward(
......
...@@ -21,7 +21,7 @@ from vllm.v1.kv_cache_interface import KVCacheSpec ...@@ -21,7 +21,7 @@ from vllm.v1.kv_cache_interface import KVCacheSpec
@functools.lru_cache @functools.lru_cache
def create_encoder_only_attention_backend( def create_encoder_only_attention_backend(
underlying_attn_backend: AttentionBackend, underlying_attn_backend: type[AttentionBackend],
) -> type[AttentionBackend]: ) -> type[AttentionBackend]:
prefix = "EncoderOnlyAttention_" prefix = "EncoderOnlyAttention_"
underlying_builder = underlying_attn_backend.get_builder_cls() underlying_builder = underlying_attn_backend.get_builder_cls()
...@@ -93,6 +93,6 @@ class EncoderOnlyAttention(Attention): ...@@ -93,6 +93,6 @@ class EncoderOnlyAttention(Attention):
**kwargs, **kwargs,
) )
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec: def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
# Does not need KV cache # Does not need KV cache
return None return None
...@@ -389,7 +389,7 @@ class MLAAttention(nn.Module, AttentionLayerBase): ...@@ -389,7 +389,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
cache_config.enable_prefix_caching = False cache_config.enable_prefix_caching = False
impl_cls = cast(type[MLAAttentionImpl], self.attn_backend.get_impl_cls()) impl_cls = cast(type[MLAAttentionImpl], self.attn_backend.get_impl_cls())
self.impl = impl_cls( self.impl = impl_cls( # type: ignore[assignment] # impl_cls always returns an MLAAttentionImpl subclass
num_heads=self.num_heads, num_heads=self.num_heads,
head_size=self.head_size, head_size=self.head_size,
scale=self.scale, scale=self.scale,
...@@ -485,16 +485,23 @@ class MLAAttention(nn.Module, AttentionLayerBase): ...@@ -485,16 +485,23 @@ class MLAAttention(nn.Module, AttentionLayerBase):
if self.use_direct_call: if self.use_direct_call:
forward_context: ForwardContext = get_forward_context() forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata attn_metadata_raw = forward_context.attn_metadata
if isinstance(attn_metadata, dict): attn_metadata: MLACommonMetadata
attn_metadata = attn_metadata[self.layer_name] if isinstance(attn_metadata_raw, dict):
attn_metadata = attn_metadata_raw[self.layer_name] # type: ignore[assignment]
elif isinstance(attn_metadata_raw, list):
# list[dict[str, AttentionMetadata]]: used in speculative decoding
# where [0] is the base-model (non-speculative) metadata dict.
attn_metadata = attn_metadata_raw[0][self.layer_name] # type: ignore[assignment]
else:
attn_metadata = attn_metadata_raw
self_kv_cache = self.kv_cache self_kv_cache = self.kv_cache
slot_mapping = forward_context.slot_mapping slot_mapping = forward_context.slot_mapping
assert isinstance(slot_mapping, dict), ( assert isinstance(slot_mapping, dict), (
f"Expected slot_mapping to be a dict, got {type(slot_mapping)}. " f"Expected slot_mapping to be a dict, got {type(slot_mapping)}. "
) )
self.impl.do_kv_cache_update( self.impl.do_kv_cache_update( # type: ignore[attr-defined]
kv_c_normed, kv_c_normed,
k_pe, k_pe,
self_kv_cache, self_kv_cache,
...@@ -612,7 +619,7 @@ class MLAAttention(nn.Module, AttentionLayerBase): ...@@ -612,7 +619,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
num_mha_tokens = q.size(0) - num_mqa_tokens num_mha_tokens = q.size(0) - num_mqa_tokens
if num_mha_tokens > 0: if num_mha_tokens > 0:
self.impl.forward_mha( self.impl.forward_mha( # type: ignore[attr-defined]
q[num_mqa_tokens:], q[num_mqa_tokens:],
k_c_normed[num_mqa_tokens:], k_c_normed[num_mqa_tokens:],
k_pe[num_mqa_tokens:], k_pe[num_mqa_tokens:],
...@@ -695,7 +702,7 @@ class MLAAttention(nn.Module, AttentionLayerBase): ...@@ -695,7 +702,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
# call decode attn # call decode attn
if not is_sparse_impl: if not is_sparse_impl:
assert attn_metadata.decode is not None assert attn_metadata.decode is not None
attn_out, lse = self.impl.forward_mqa(mqa_q, kv_cache, attn_metadata, self) attn_out, lse = self.impl.forward_mqa(mqa_q, kv_cache, attn_metadata, self) # type: ignore[attr-defined]
# correct dcp attn_out with lse. # correct dcp attn_out with lse.
if self.impl.dcp_world_size > 1: if self.impl.dcp_world_size > 1:
...@@ -1053,9 +1060,9 @@ except ImportError: ...@@ -1053,9 +1060,9 @@ except ImportError:
"AITER_MLA backends use aiter kernels instead." "AITER_MLA backends use aiter kernels instead."
) )
elif current_platform.is_xpu(): elif current_platform.is_xpu():
from vllm._xpu_ops import xpu_ops as ops from vllm._xpu_ops import xpu_ops
flash_attn_varlen_func = ops.flash_attn_varlen_func # type: ignore[no-redef] flash_attn_varlen_func = xpu_ops.flash_attn_varlen_func # type: ignore[no-redef,attr-defined,assignment]
def dynamic_per_batched_tensor_quant( def dynamic_per_batched_tensor_quant(
...@@ -1988,7 +1995,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -1988,7 +1995,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
assert isinstance(attn_metadata.prefill, FlashInferPrefillMetadata) assert isinstance(attn_metadata.prefill, FlashInferPrefillMetadata)
self._build_fi_prefill_wrappers(attn_metadata.prefill) self._build_fi_prefill_wrappers(attn_metadata.prefill)
return attn_metadata return attn_metadata # type: ignore[return-value]
def reorg_kvcache( def reorg_kvcache(
......
...@@ -117,17 +117,20 @@ def maybe_make_prepare_finalize( ...@@ -117,17 +117,20 @@ def maybe_make_prepare_finalize(
"Detected DP deployment with no --enable-expert-parallel. " "Detected DP deployment with no --enable-expert-parallel. "
"Falling back to AllGather+ReduceScatter dispatch/combine." "Falling back to AllGather+ReduceScatter dispatch/combine."
) )
device_communicator = get_ep_group().device_communicator
assert device_communicator is not None
assert device_communicator.all2all_manager is not None
return make_moe_prepare_and_finalize_naive_dp_ep( return make_moe_prepare_and_finalize_naive_dp_ep(
is_sequence_parallel=moe.moe_parallel_config.is_sequence_parallel, is_sequence_parallel=moe.moe_parallel_config.is_sequence_parallel,
num_dispatchers=( num_dispatchers=(device_communicator.all2all_manager.world_size),
get_ep_group().device_communicator.all2all_manager.world_size
),
use_monolithic=use_monolithic, use_monolithic=use_monolithic,
) )
else: else:
return make_moe_prepare_and_finalize_no_dp_ep(use_monolithic) return make_moe_prepare_and_finalize_no_dp_ep(use_monolithic)
all2all_manager = get_ep_group().device_communicator.all2all_manager device_communicator = get_ep_group().device_communicator
assert device_communicator is not None
all2all_manager = device_communicator.all2all_manager
assert all2all_manager is not None assert all2all_manager is not None
prepare_finalize: FusedMoEPrepareAndFinalize | None = None prepare_finalize: FusedMoEPrepareAndFinalize | None = None
......
...@@ -7,6 +7,7 @@ from typing import Union ...@@ -7,6 +7,7 @@ from typing import Union
import torch import torch
from vllm.config import ParallelConfig, SchedulerConfig from vllm.config import ParallelConfig, SchedulerConfig
from vllm.config.kernel import MoEBackend
from vllm.distributed import get_dp_group, get_pcp_group, get_tensor_model_parallel_rank from vllm.distributed import get_dp_group, get_pcp_group, get_tensor_model_parallel_rank
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.activation import MoEActivation
...@@ -1192,7 +1193,7 @@ class FusedMoEConfig: ...@@ -1192,7 +1193,7 @@ class FusedMoEConfig:
# Defaults to intermediate_size_per_partition if not specified. # Defaults to intermediate_size_per_partition if not specified.
intermediate_size_per_partition_unpadded: int | None = None intermediate_size_per_partition_unpadded: int | None = None
moe_backend: str = "auto" moe_backend: MoEBackend = "auto"
max_num_tokens: int = SchedulerConfig.DEFAULT_MAX_NUM_BATCHED_TOKENS_FOR_BATCHED_DP max_num_tokens: int = SchedulerConfig.DEFAULT_MAX_NUM_BATCHED_TOKENS_FOR_BATCHED_DP
has_bias: bool = False has_bias: bool = False
is_act_and_mul: bool = True is_act_and_mul: bool = True
......
...@@ -210,9 +210,9 @@ def persistent_masked_m_silu_mul_quant( ...@@ -210,9 +210,9 @@ def persistent_masked_m_silu_mul_quant(
DeepGemmQuantScaleFMT.UE8M0, DeepGemmQuantScaleFMT.UE8M0,
] ]
cuda_arch = current_platform.get_device_capability( device_capability = current_platform.get_device_capability(device_id=y.device.index)
device_id=y.device.index assert device_capability is not None
).to_int() cuda_arch = device_capability.to_int()
if current_platform.is_cuda() and cuda_arch >= 80: if current_platform.is_cuda() and cuda_arch >= 80:
torch.ops._C.persistent_masked_m_silu_mul_quant( torch.ops._C.persistent_masked_m_silu_mul_quant(
......
...@@ -7,6 +7,7 @@ import torch ...@@ -7,6 +7,7 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import envs from vllm import envs
from vllm.config.kernel import MoEBackend
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import ( from vllm.model_executor.layers.fused_moe import (
FusedMoEConfig, FusedMoEConfig,
...@@ -146,7 +147,7 @@ def backend_to_kernel_cls( ...@@ -146,7 +147,7 @@ def backend_to_kernel_cls(
raise ValueError(f"Unknown MXFP4 MoE backend: {backend.value}") raise ValueError(f"Unknown MXFP4 MoE backend: {backend.value}")
def map_mxfp4_backend(runner_backend: str) -> Mxfp4MoeBackend: def map_mxfp4_backend(runner_backend: MoEBackend) -> Mxfp4MoeBackend:
"""Map user's moe_backend string to Mxfp4MoeBackend.""" """Map user's moe_backend string to Mxfp4MoeBackend."""
mapping: dict[str, Mxfp4MoeBackend] = { mapping: dict[str, Mxfp4MoeBackend] = {
"flashinfer_trtllm": Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16, "flashinfer_trtllm": Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16,
...@@ -201,10 +202,12 @@ def select_gpt_oss_mxfp4_moe_backend( ...@@ -201,10 +202,12 @@ def select_gpt_oss_mxfp4_moe_backend(
Select the primary MXFP4 MoE backend. Select the primary MXFP4 MoE backend.
Note: Shape-specific fallbacks may still occur at runtime. Note: Shape-specific fallbacks may still occur at runtime.
""" """
triton_kernels_supported = has_triton_kernels() and ( device_capability = current_platform.get_device_capability()
9, triton_kernels_supported = (
0, has_triton_kernels()
) <= current_platform.get_device_capability() < (11, 0) and device_capability is not None
and (9, 0) <= device_capability < (11, 0)
)
# LoRA: separate experts backend path # LoRA: separate experts backend path
if config.is_lora_enabled: if config.is_lora_enabled:
......
...@@ -4,6 +4,9 @@ import torch ...@@ -4,6 +4,9 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.distributed import get_ep_group from vllm.distributed import get_ep_group
from vllm.distributed.device_communicators.base_device_communicator import (
All2AllManagerBase,
)
from vllm.forward_context import get_forward_context from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
...@@ -11,12 +14,16 @@ from vllm.utils.flashinfer import nvfp4_block_scale_interleave ...@@ -11,12 +14,16 @@ from vllm.utils.flashinfer import nvfp4_block_scale_interleave
def get_local_sizes(): def get_local_sizes():
return get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank() dp_metadata = get_forward_context().dp_metadata
assert dp_metadata is not None
return dp_metadata.get_chunk_sizes_across_dp_rank()
class FlashInferNVLinkOneSidedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular): class FlashInferNVLinkOneSidedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
"""FlashInfer implementation using the Moe AlltoAll kernel.""" """FlashInfer implementation using the Moe AlltoAll kernel."""
all2all_manager: All2AllManagerBase
def __init__( def __init__(
self, self,
max_num_tokens: int, max_num_tokens: int,
...@@ -32,8 +39,12 @@ class FlashInferNVLinkOneSidedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeMo ...@@ -32,8 +39,12 @@ class FlashInferNVLinkOneSidedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeMo
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.num_dispatchers_ = num_dispatchers self.num_dispatchers_ = num_dispatchers
self.all2all_manager = get_ep_group().device_communicator.all2all_manager device_communicator = get_ep_group().device_communicator
self.all2all_manager.initialize( assert device_communicator is not None
all2all_manager = device_communicator.all2all_manager
assert all2all_manager is not None
self.all2all_manager = all2all_manager
self.all2all_manager.initialize( # type: ignore[attr-defined]
max_num_tokens=self.max_num_tokens, max_num_tokens=self.max_num_tokens,
top_k=self.top_k, top_k=self.top_k,
num_experts=self.num_experts, num_experts=self.num_experts,
...@@ -97,7 +108,8 @@ class FlashInferNVLinkOneSidedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeMo ...@@ -97,7 +108,8 @@ class FlashInferNVLinkOneSidedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeMo
payloads.append(topk_ids) payloads.append(topk_ids)
payloads.append(topk_weights) payloads.append(topk_weights)
recv_payloads = self.all2all_manager.moe_alltoall.dispatch( assert self.all2all_manager.moe_alltoall is not None # type: ignore[attr-defined]
recv_payloads = self.all2all_manager.moe_alltoall.dispatch( # type: ignore[attr-defined]
token_selected_experts=topk_ids, token_selected_experts=topk_ids,
input_payloads=payloads, input_payloads=payloads,
runtime_max_tokens_per_rank=self.runtime_max_tokens_per_rank, runtime_max_tokens_per_rank=self.runtime_max_tokens_per_rank,
...@@ -131,7 +143,7 @@ class FlashInferNVLinkOneSidedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeMo ...@@ -131,7 +143,7 @@ class FlashInferNVLinkOneSidedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeMo
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce, weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> None: ) -> None:
assert self.all2all_manager.moe_alltoall is not None assert self.all2all_manager.moe_alltoall is not None # type: ignore[attr-defined]
ep_size = self.all2all_manager.world_size ep_size = self.all2all_manager.world_size
hidden_size = fused_expert_output.shape[-1] hidden_size = fused_expert_output.shape[-1]
...@@ -139,7 +151,7 @@ class FlashInferNVLinkOneSidedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeMo ...@@ -139,7 +151,7 @@ class FlashInferNVLinkOneSidedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeMo
ep_size, self.runtime_max_tokens_per_rank, hidden_size ep_size, self.runtime_max_tokens_per_rank, hidden_size
) )
combined_output = self.all2all_manager.moe_alltoall.combine( combined_output = self.all2all_manager.moe_alltoall.combine( # type: ignore[attr-defined]
payload=fused_expert_output, payload=fused_expert_output,
runtime_max_tokens_per_rank=self.runtime_max_tokens_per_rank, runtime_max_tokens_per_rank=self.runtime_max_tokens_per_rank,
) )
......
...@@ -15,19 +15,26 @@ from vllm.utils.flashinfer import nvfp4_block_scale_interleave ...@@ -15,19 +15,26 @@ from vllm.utils.flashinfer import nvfp4_block_scale_interleave
def get_local_sizes(): def get_local_sizes():
return get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank() dp_metadata = get_forward_context().dp_metadata
assert dp_metadata is not None
return dp_metadata.get_chunk_sizes_across_dp_rank()
class FlashInferNVLinkTwoSidedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular): class FlashInferNVLinkTwoSidedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
"""Base class for FlashInfer MoE prepare and finalize operations.""" """Base class for FlashInfer MoE prepare and finalize operations."""
all2all_manager: All2AllManagerBase
def __init__( def __init__(
self, self,
num_dispatchers: int = 1, num_dispatchers: int = 1,
): ):
super().__init__() super().__init__()
self.num_dispatchers_ = num_dispatchers self.num_dispatchers_ = num_dispatchers
self.all2all_manager = get_ep_group().device_communicator.all2all_manager device_communicator = get_ep_group().device_communicator
assert device_communicator is not None
assert device_communicator.all2all_manager is not None
self.all2all_manager = device_communicator.all2all_manager
@property @property
def activation_format(self) -> mk.FusedMoEActivationFormat: def activation_format(self) -> mk.FusedMoEActivationFormat:
...@@ -129,7 +136,7 @@ def flashinfer_alltoall_dispatch( ...@@ -129,7 +136,7 @@ def flashinfer_alltoall_dispatch(
): ):
from flashinfer.comm.trtllm_alltoall import MnnvlMoe from flashinfer.comm.trtllm_alltoall import MnnvlMoe
assert all2all_manager.ensure_alltoall_workspace_initialized(), ( assert all2all_manager.ensure_alltoall_workspace_initialized(), ( # type: ignore[attr-defined]
"FlashInfer AllToAll workspace not available" "FlashInfer AllToAll workspace not available"
) )
...@@ -144,7 +151,7 @@ def flashinfer_alltoall_dispatch( ...@@ -144,7 +151,7 @@ def flashinfer_alltoall_dispatch(
topk_ids, topk_ids,
topk_weights, topk_weights,
None, None,
all2all_manager.prepare_workspace_tensor, all2all_manager.prepare_workspace_tensor, # type: ignore[attr-defined]
max_num_token, max_num_token,
ep_rank, ep_rank,
ep_size, ep_size,
...@@ -172,7 +179,7 @@ def flashinfer_alltoall_dispatch( ...@@ -172,7 +179,7 @@ def flashinfer_alltoall_dispatch(
x = MnnvlMoe.mnnvl_moe_alltoallv( x = MnnvlMoe.mnnvl_moe_alltoallv(
x, x,
alltoall_info, alltoall_info,
all2all_manager.workspace_tensor, all2all_manager.workspace_tensor, # type: ignore[attr-defined]
ep_rank, ep_rank,
ep_size, ep_size,
) )
...@@ -180,7 +187,7 @@ def flashinfer_alltoall_dispatch( ...@@ -180,7 +187,7 @@ def flashinfer_alltoall_dispatch(
x_sf = MnnvlMoe.mnnvl_moe_alltoallv( x_sf = MnnvlMoe.mnnvl_moe_alltoallv(
x_sf, x_sf,
alltoall_info, alltoall_info,
all2all_manager.workspace_tensor, all2all_manager.workspace_tensor, # type: ignore[attr-defined]
ep_rank, ep_rank,
ep_size, ep_size,
) )
...@@ -196,7 +203,7 @@ def flashinfer_alltoall_dispatch( ...@@ -196,7 +203,7 @@ def flashinfer_alltoall_dispatch(
x = MnnvlMoe.mnnvl_moe_alltoallv( x = MnnvlMoe.mnnvl_moe_alltoallv(
x, x,
alltoall_info, alltoall_info,
all2all_manager.workspace_tensor, all2all_manager.workspace_tensor, # type: ignore[attr-defined]
ep_rank, ep_rank,
ep_size, ep_size,
) )
...@@ -212,13 +219,13 @@ def flashinfer_alltoall_combine( ...@@ -212,13 +219,13 @@ def flashinfer_alltoall_combine(
): ):
from flashinfer.comm.trtllm_alltoall import MnnvlMoe from flashinfer.comm.trtllm_alltoall import MnnvlMoe
assert all2all_manager.ensure_alltoall_workspace_initialized(), ( assert all2all_manager.ensure_alltoall_workspace_initialized(), ( # type: ignore[attr-defined]
"FlashInfer AllToAll workspace not available" "FlashInfer AllToAll workspace not available"
) )
return MnnvlMoe.mnnvl_moe_alltoallv_combine( return MnnvlMoe.mnnvl_moe_alltoallv_combine(
output, output,
alltoall_info, alltoall_info,
all2all_manager.workspace_tensor, all2all_manager.workspace_tensor, # type: ignore[attr-defined]
ep_rank=all2all_manager.rank, ep_rank=all2all_manager.rank,
ep_size=all2all_manager.world_size, ep_size=all2all_manager.world_size,
top_k=top_k, top_k=top_k,
......
...@@ -132,9 +132,11 @@ class MoEPrepareAndFinalizeNaiveDPEPModular(mk.FusedMoEPrepareAndFinalizeModular ...@@ -132,9 +132,11 @@ class MoEPrepareAndFinalizeNaiveDPEPModular(mk.FusedMoEPrepareAndFinalizeModular
) )
if scales is None: if scales is None:
assert len(res) == 3
a1q, topk_weights, topk_ids = res a1q, topk_weights, topk_ids = res
a1q_scale = None a1q_scale = None
else: else:
assert len(res) == 4
a1q, topk_weights, topk_ids, scales = res a1q, topk_weights, topk_ids, scales = res
a1q_scale = _unwrap_scale_and_prepare_for_moe(scales, quant_config) a1q_scale = _unwrap_scale_and_prepare_for_moe(scales, quant_config)
...@@ -217,9 +219,11 @@ class MoEPrepareAndFinalizeNaiveDPEPMonolithic(mk.FusedMoEPrepareAndFinalizeMono ...@@ -217,9 +219,11 @@ class MoEPrepareAndFinalizeNaiveDPEPMonolithic(mk.FusedMoEPrepareAndFinalizeMono
) )
if scales is None: if scales is None:
assert len(res) == 2
a1q, router_logits = res a1q, router_logits = res
a1q_scale = None a1q_scale = None
else: else:
assert len(res) == 3
a1q, router_logits, scales = res a1q, router_logits, scales = res
a1q_scale = _unwrap_scale_and_prepare_for_moe(scales, quant_config) a1q_scale = _unwrap_scale_and_prepare_for_moe(scales, quant_config)
......
...@@ -54,11 +54,13 @@ class DefaultMoERunner(MoERunnerBase): ...@@ -54,11 +54,13 @@ class DefaultMoERunner(MoERunnerBase):
# NOTE: this will be removed once all kernels are migrated into the # NOTE: this will be removed once all kernels are migrated into the
# MoEKernel framework. # MoEKernel framework.
if self.do_naive_dispatch_combine: if self.do_naive_dispatch_combine:
hidden_states, router_logits = get_ep_group().dispatch_router_logits( res = get_ep_group().dispatch_router_logits(
hidden_states, hidden_states,
router_logits, router_logits,
self.moe_config.is_sequence_parallel, self.moe_config.is_sequence_parallel,
) )
assert len(res) == 2
hidden_states, router_logits = res
# NOTE: Similar with DP, PCP also needs dispatch and combine. For # NOTE: Similar with DP, PCP also needs dispatch and combine. For
# simplicity, AgRsAll2All was added separately for PCP here. Maybe # simplicity, AgRsAll2All was added separately for PCP here. Maybe
......
...@@ -16,7 +16,6 @@ from vllm.logger import init_logger ...@@ -16,7 +16,6 @@ from vllm.logger import init_logger
from vllm.model_executor.model_loader.weight_utils import sharded_weight_loader from vllm.model_executor.model_loader.weight_utils import sharded_weight_loader
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.utils.torch_utils import direct_register_custom_op from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
from .fla.ops.kda import ( from .fla.ops.kda import (
...@@ -123,7 +122,7 @@ class KimiDeltaAttention(nn.Module, MambaBase): ...@@ -123,7 +122,7 @@ class KimiDeltaAttention(nn.Module, MambaBase):
self.cache_config = cache_config self.cache_config = cache_config
if model_config is None: if model_config is None:
raise ValueError("model_config must be provided") raise ValueError("model_config must be provided")
kda_config = model_config.linear_attn_config kda_config = model_config.linear_attn_config # type: ignore[attr-defined]
self.head_dim = kda_config["head_dim"] self.head_dim = kda_config["head_dim"]
self.num_heads = kda_config["num_heads"] self.num_heads = kda_config["num_heads"]
self.layer_idx = layer_idx self.layer_idx = layer_idx
...@@ -297,19 +296,21 @@ class KimiDeltaAttention(nn.Module, MambaBase): ...@@ -297,19 +296,21 @@ class KimiDeltaAttention(nn.Module, MambaBase):
core_attn_out: torch.Tensor, core_attn_out: torch.Tensor,
) -> None: ) -> None:
forward_context = get_forward_context() forward_context = get_forward_context()
attn_metadata: AttentionMetadata = forward_context.attn_metadata attn_metadata_raw = forward_context.attn_metadata
if attn_metadata is None: if attn_metadata_raw is None:
# # V1 profile run # # V1 profile run
return return
assert isinstance(attn_metadata, dict) assert isinstance(attn_metadata_raw, dict)
attn_metadata = attn_metadata[self.prefix] attn_metadata_narrowed = attn_metadata_raw[self.prefix]
assert isinstance(attn_metadata, GDNAttentionMetadata) assert isinstance(attn_metadata_narrowed, GDNAttentionMetadata)
has_initial_state = attn_metadata.has_initial_state has_initial_state = attn_metadata_narrowed.has_initial_state
non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc non_spec_query_start_loc = attn_metadata_narrowed.non_spec_query_start_loc
non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501 non_spec_state_indices_tensor = (
num_actual_tokens = attn_metadata.num_actual_tokens attn_metadata_narrowed.non_spec_state_indices_tensor
) # noqa: E501
num_actual_tokens = attn_metadata_narrowed.num_actual_tokens
constant_caches = self.kv_cache constant_caches = self.kv_cache
q_proj_states = q_proj_states[:num_actual_tokens] q_proj_states = q_proj_states[:num_actual_tokens]
...@@ -335,7 +336,7 @@ class KimiDeltaAttention(nn.Module, MambaBase): ...@@ -335,7 +336,7 @@ class KimiDeltaAttention(nn.Module, MambaBase):
v_conv_weights = self.v_conv1d.weight.view( v_conv_weights = self.v_conv1d.weight.view(
self.v_conv1d.weight.size(0), self.v_conv1d.weight.size(2) self.v_conv1d.weight.size(0), self.v_conv1d.weight.size(2)
) )
if attn_metadata.num_prefills > 0: if attn_metadata_narrowed.num_prefills > 0:
q_proj_states = q_proj_states.transpose(0, 1) q_proj_states = q_proj_states.transpose(0, 1)
k_proj_states = k_proj_states.transpose(0, 1) k_proj_states = k_proj_states.transpose(0, 1)
v_proj_states = v_proj_states.transpose(0, 1) v_proj_states = v_proj_states.transpose(0, 1)
...@@ -348,7 +349,7 @@ class KimiDeltaAttention(nn.Module, MambaBase): ...@@ -348,7 +349,7 @@ class KimiDeltaAttention(nn.Module, MambaBase):
has_initial_state=has_initial_state, has_initial_state=has_initial_state,
cache_indices=non_spec_state_indices_tensor, cache_indices=non_spec_state_indices_tensor,
query_start_loc=non_spec_query_start_loc, query_start_loc=non_spec_query_start_loc,
metadata=attn_metadata, metadata=attn_metadata_narrowed,
).transpose(0, 1) ).transpose(0, 1)
k = causal_conv1d_fn( k = causal_conv1d_fn(
k_proj_states, k_proj_states,
...@@ -359,7 +360,7 @@ class KimiDeltaAttention(nn.Module, MambaBase): ...@@ -359,7 +360,7 @@ class KimiDeltaAttention(nn.Module, MambaBase):
has_initial_state=has_initial_state, has_initial_state=has_initial_state,
cache_indices=non_spec_state_indices_tensor, cache_indices=non_spec_state_indices_tensor,
query_start_loc=non_spec_query_start_loc, query_start_loc=non_spec_query_start_loc,
metadata=attn_metadata, metadata=attn_metadata_narrowed,
).transpose(0, 1) ).transpose(0, 1)
v = causal_conv1d_fn( v = causal_conv1d_fn(
v_proj_states, v_proj_states,
...@@ -370,11 +371,12 @@ class KimiDeltaAttention(nn.Module, MambaBase): ...@@ -370,11 +371,12 @@ class KimiDeltaAttention(nn.Module, MambaBase):
has_initial_state=has_initial_state, has_initial_state=has_initial_state,
cache_indices=non_spec_state_indices_tensor, cache_indices=non_spec_state_indices_tensor,
query_start_loc=non_spec_query_start_loc, query_start_loc=non_spec_query_start_loc,
metadata=attn_metadata, metadata=attn_metadata_narrowed,
).transpose(0, 1) ).transpose(0, 1)
else: else:
assert non_spec_state_indices_tensor is not None
decode_conv_indices = non_spec_state_indices_tensor[ decode_conv_indices = non_spec_state_indices_tensor[
: attn_metadata.num_actual_tokens : attn_metadata_narrowed.num_actual_tokens
] ]
q = causal_conv1d_update( q = causal_conv1d_update(
q_proj_states, q_proj_states,
...@@ -408,7 +410,9 @@ class KimiDeltaAttention(nn.Module, MambaBase): ...@@ -408,7 +410,9 @@ class KimiDeltaAttention(nn.Module, MambaBase):
lambda x: rearrange(x, "n (h d) -> 1 n h d", d=self.head_dim), (q, k, v) lambda x: rearrange(x, "n (h d) -> 1 n h d", d=self.head_dim), (q, k, v)
) )
if attn_metadata.num_prefills > 0: if attn_metadata_narrowed.num_prefills > 0:
assert non_spec_state_indices_tensor is not None
assert has_initial_state is not None
zero_idx = non_spec_state_indices_tensor[~has_initial_state] zero_idx = non_spec_state_indices_tensor[~has_initial_state]
recurrent_state[zero_idx] = 0 recurrent_state[zero_idx] = 0
initial_state = recurrent_state[non_spec_state_indices_tensor].contiguous() initial_state = recurrent_state[non_spec_state_indices_tensor].contiguous()
...@@ -429,6 +433,7 @@ class KimiDeltaAttention(nn.Module, MambaBase): ...@@ -429,6 +433,7 @@ class KimiDeltaAttention(nn.Module, MambaBase):
# Init cache # Init cache
recurrent_state[non_spec_state_indices_tensor] = last_recurrent_state recurrent_state[non_spec_state_indices_tensor] = last_recurrent_state
else: else:
assert non_spec_query_start_loc is not None
( (
core_attn_out_non_spec, core_attn_out_non_spec,
last_recurrent_state, last_recurrent_state,
...@@ -440,7 +445,9 @@ class KimiDeltaAttention(nn.Module, MambaBase): ...@@ -440,7 +445,9 @@ class KimiDeltaAttention(nn.Module, MambaBase):
beta=beta, beta=beta,
initial_state=recurrent_state, initial_state=recurrent_state,
use_qk_l2norm_in_kernel=True, use_qk_l2norm_in_kernel=True,
cu_seqlens=non_spec_query_start_loc[: attn_metadata.num_decodes + 1], cu_seqlens=non_spec_query_start_loc[
: attn_metadata_narrowed.num_decodes + 1
],
ssm_state_indices=non_spec_state_indices_tensor, ssm_state_indices=non_spec_state_indices_tensor,
) )
core_attn_out[0, :num_actual_tokens] = core_attn_out_non_spec[ core_attn_out[0, :num_actual_tokens] = core_attn_out_non_spec[
......
...@@ -76,7 +76,7 @@ def poly_norm( ...@@ -76,7 +76,7 @@ def poly_norm(
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
out = torch.empty_like(x) out = torch.empty_like(x)
ops.poly_norm( ops.poly_norm( # type: ignore[attr-defined]
out, out,
x, x,
weight, weight,
......
...@@ -42,9 +42,10 @@ class MambaBase(AttentionLayerBase): ...@@ -42,9 +42,10 @@ class MambaBase(AttentionLayerBase):
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None: def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
mamba_block_size = vllm_config.cache_config.mamba_block_size mamba_block_size = vllm_config.cache_config.mamba_block_size
assert mamba_block_size is not None
page_size_padded = vllm_config.cache_config.mamba_page_size_padded page_size_padded = vllm_config.cache_config.mamba_page_size_padded
return MambaSpec( return MambaSpec(
shapes=self.get_state_shape(), shapes=tuple(self.get_state_shape()),
dtypes=self.get_state_dtype(), dtypes=self.get_state_dtype(),
block_size=mamba_block_size, block_size=mamba_block_size,
page_size_padded=page_size_padded, page_size_padded=page_size_padded,
......
...@@ -62,7 +62,6 @@ from vllm.utils.torch_utils import ( ...@@ -62,7 +62,6 @@ from vllm.utils.torch_utils import (
_resolve_layer_name, _resolve_layer_name,
direct_register_custom_op, direct_register_custom_op,
) )
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -121,9 +120,9 @@ def fi_chunk_gated_delta_rule( ...@@ -121,9 +120,9 @@ def fi_chunk_gated_delta_rule(
class ChunkGatedDeltaRule(CustomOp): class ChunkGatedDeltaRule(CustomOp):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
backend_cfg = get_current_vllm_config().additional_config.get( additional_config = get_current_vllm_config().additional_config
"gdn_prefill_backend", "auto" assert isinstance(additional_config, dict)
) backend_cfg = additional_config.get("gdn_prefill_backend", "auto")
backend = str(backend_cfg).strip().lower() backend = str(backend_cfg).strip().lower()
supports_flashinfer = ( supports_flashinfer = (
...@@ -621,18 +620,19 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase): ...@@ -621,18 +620,19 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
# Part 2: Core Attention # Part 2: Core Attention
# ============================================================ # ============================================================
forward_context = get_forward_context() forward_context = get_forward_context()
attn_metadata: AttentionMetadata = forward_context.attn_metadata attn_metadata_raw = forward_context.attn_metadata
core_attn_out = torch.zeros( core_attn_out = torch.zeros(
(num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim), (num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim),
dtype=hidden_states.dtype, dtype=hidden_states.dtype,
device=hidden_states.device, device=hidden_states.device,
) )
z = torch.empty_like(core_attn_out) z = torch.empty_like(core_attn_out)
if attn_metadata is not None: if attn_metadata_raw is not None:
attn_metadata = attn_metadata[self.prefix] assert isinstance(attn_metadata_raw, dict)
attn_metadata = attn_metadata_raw[self.prefix]
# TODO: xpu does not support this param yet # TODO: xpu does not support this param yet
spec_sequence_masks = attn_metadata.spec_sequence_masks spec_sequence_masks = attn_metadata.spec_sequence_masks # type: ignore[attr-defined]
assert spec_sequence_masks is None assert spec_sequence_masks is None
conv_weights = self.conv1d.weight.view( conv_weights = self.conv1d.weight.view(
...@@ -658,12 +658,12 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase): ...@@ -658,12 +658,12 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
activation=self.activation, activation=self.activation,
A_log=self.A_log, A_log=self.A_log,
dt_bias=self.dt_bias, dt_bias=self.dt_bias,
num_prefills=attn_metadata.num_prefills, num_prefills=attn_metadata.num_prefills, # type: ignore[attr-defined]
num_decodes=attn_metadata.num_decodes, num_decodes=attn_metadata.num_decodes, # type: ignore[attr-defined]
has_initial_state=attn_metadata.has_initial_state, has_initial_state=attn_metadata.has_initial_state, # type: ignore[attr-defined]
non_spec_query_start_loc=attn_metadata.non_spec_query_start_loc, non_spec_query_start_loc=attn_metadata.non_spec_query_start_loc, # type: ignore[attr-defined]
non_spec_state_indices_tensor=attn_metadata.non_spec_state_indices_tensor, non_spec_state_indices_tensor=attn_metadata.non_spec_state_indices_tensor, # type: ignore[attr-defined]
num_actual_tokens=attn_metadata.num_actual_tokens, num_actual_tokens=attn_metadata.num_actual_tokens, # type: ignore[attr-defined]
tp_size=self.tp_size, tp_size=self.tp_size,
reorder_input=not self.gqa_interleaved_layout, reorder_input=not self.gqa_interleaved_layout,
) )
...@@ -792,16 +792,16 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase): ...@@ -792,16 +792,16 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
core_attn_out: torch.Tensor, core_attn_out: torch.Tensor,
): ):
forward_context = get_forward_context() forward_context = get_forward_context()
attn_metadata: AttentionMetadata = forward_context.attn_metadata attn_metadata_raw = forward_context.attn_metadata
if attn_metadata is None: if attn_metadata_raw is None:
# V1 profile run — warm up prefill kernels so that # V1 profile run — warm up prefill kernels so that
# autotuning completes before KV cache allocation. # autotuning completes before KV cache allocation.
self._warmup_prefill_kernels(mixed_qkv) self._warmup_prefill_kernels(mixed_qkv)
return return
assert isinstance(attn_metadata, dict) assert isinstance(attn_metadata_raw, dict)
attn_metadata = attn_metadata[self.prefix] attn_metadata = attn_metadata_raw[self.prefix] # type: ignore[index]
assert isinstance(attn_metadata, GDNAttentionMetadata) assert isinstance(attn_metadata, GDNAttentionMetadata)
if ( if (
...@@ -860,14 +860,16 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase): ...@@ -860,14 +860,16 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
# 1.1: Process the multi-query part # 1.1: Process the multi-query part
if spec_sequence_masks is not None: if spec_sequence_masks is not None:
# spec_state_indices_tensor is always set when spec_sequence_masks is set
assert spec_state_indices_tensor is not None
mixed_qkv_spec = causal_conv1d_update( mixed_qkv_spec = causal_conv1d_update(
mixed_qkv_spec, mixed_qkv_spec,
conv_state, conv_state,
conv_weights, conv_weights,
self.conv1d.bias, self.conv1d.bias,
self.activation, self.activation,
conv_state_indices=spec_state_indices_tensor[:, 0][ conv_state_indices=spec_state_indices_tensor[:, 0][ # type: ignore[index]
: attn_metadata.num_spec_decodes : attn_metadata.num_spec_decodes # type: ignore[attr-defined]
], ],
num_accepted_tokens=num_accepted_tokens, num_accepted_tokens=num_accepted_tokens,
query_start_loc=spec_query_start_loc, query_start_loc=spec_query_start_loc,
...@@ -900,8 +902,8 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase): ...@@ -900,8 +902,8 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
conv_weights, conv_weights,
self.conv1d.bias, self.conv1d.bias,
self.activation, self.activation,
conv_state_indices=non_spec_state_indices_tensor[ conv_state_indices=non_spec_state_indices_tensor[ # type: ignore[index]
: attn_metadata.num_actual_tokens : attn_metadata.num_actual_tokens # type: ignore[attr-defined]
], ],
validate_data=True, validate_data=True,
) )
...@@ -965,8 +967,9 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase): ...@@ -965,8 +967,9 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
v=value_spec, v=value_spec,
initial_state=ssm_state, initial_state=ssm_state,
inplace_final_state=True, inplace_final_state=True,
cu_seqlens=spec_query_start_loc[ cu_seqlens=spec_query_start_loc[ # type: ignore[index]
: attn_metadata.num_spec_decodes + 1 : attn_metadata.num_spec_decodes
+ 1 # type: ignore[attr-defined]
], ],
ssm_state_indices=spec_state_indices_tensor, ssm_state_indices=spec_state_indices_tensor,
num_accepted_tokens=num_accepted_tokens, num_accepted_tokens=num_accepted_tokens,
...@@ -978,8 +981,10 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase): ...@@ -978,8 +981,10 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
# 2.2: Process the remaining part # 2.2: Process the remaining part
if attn_metadata.num_prefills > 0: if attn_metadata.num_prefills > 0:
initial_state = ssm_state[non_spec_state_indices_tensor].contiguous() assert non_spec_state_indices_tensor is not None
initial_state[~has_initial_state, ...] = 0 initial_state = ssm_state[non_spec_state_indices_tensor].contiguous() # type: ignore[index]
assert has_initial_state is not None
initial_state[~has_initial_state, ...] = 0 # type: ignore[operator]
( (
core_attn_out_non_spec, core_attn_out_non_spec,
last_recurrent_state, last_recurrent_state,
...@@ -1012,8 +1017,9 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase): ...@@ -1012,8 +1017,9 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
v=value_non_spec, v=value_non_spec,
initial_state=ssm_state, initial_state=ssm_state,
inplace_final_state=True, inplace_final_state=True,
cu_seqlens=non_spec_query_start_loc[ cu_seqlens=non_spec_query_start_loc[ # type: ignore[index]
: attn_metadata.num_decodes + 1 : attn_metadata.num_decodes
+ 1 # type: ignore[attr-defined]
], ],
ssm_state_indices=non_spec_state_indices_tensor, ssm_state_indices=non_spec_state_indices_tensor,
use_qk_l2norm_in_kernel=True, use_qk_l2norm_in_kernel=True,
...@@ -1073,7 +1079,7 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase): ...@@ -1073,7 +1079,7 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
conv_weights, conv_weights,
self.conv1d.bias, self.conv1d.bias,
self.activation, self.activation,
conv_state_indices=non_spec_state_indices_tensor[:num_actual_tokens], conv_state_indices=non_spec_state_indices_tensor[:num_actual_tokens], # type: ignore[index]
validate_data=False, validate_data=False,
) )
out_buf = core_attn_out[:num_actual_tokens].unsqueeze(1) out_buf = core_attn_out[:num_actual_tokens].unsqueeze(1)
...@@ -1086,7 +1092,7 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase): ...@@ -1086,7 +1092,7 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
scale=self.head_k_dim**-0.5, scale=self.head_k_dim**-0.5,
initial_state=ssm_state, initial_state=ssm_state,
out=out_buf, out=out_buf,
ssm_state_indices=non_spec_state_indices_tensor[:num_actual_tokens], ssm_state_indices=non_spec_state_indices_tensor[:num_actual_tokens], # type: ignore[index]
use_qk_l2norm_in_kernel=True, use_qk_l2norm_in_kernel=True,
) )
return return
......
...@@ -396,10 +396,11 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase): ...@@ -396,10 +396,11 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
self, hidden_states: torch.Tensor, output: torch.Tensor, positions: torch.Tensor self, hidden_states: torch.Tensor, output: torch.Tensor, positions: torch.Tensor
) -> None: ) -> None:
forward_context = get_forward_context() forward_context = get_forward_context()
attn_metadata: AttentionMetadata = forward_context.attn_metadata attn_metadata_raw = forward_context.attn_metadata
if attn_metadata is not None: attn_metadata: AttentionMetadata | None = None
assert isinstance(attn_metadata, dict) if attn_metadata_raw is not None:
attn_metadata = attn_metadata[self.prefix] assert isinstance(attn_metadata_raw, dict)
attn_metadata = attn_metadata_raw[self.prefix]
assert isinstance(attn_metadata, LinearAttentionMetadata) assert isinstance(attn_metadata, LinearAttentionMetadata)
num_actual_tokens = ( num_actual_tokens = (
attn_metadata.num_prefill_tokens + attn_metadata.num_decode_tokens attn_metadata.num_prefill_tokens + attn_metadata.num_decode_tokens
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment