"tests/models/decoder_only/language/test_granite.py" did not exist on "95248677014cb10a9dbaa2e72f688e1a6e6cf566"
Unverified Commit a608b4c6 authored by Matthew Bonanni's avatar Matthew Bonanni Committed by GitHub
Browse files

[5/N][Attention] Finish eliminating `vllm/attention` folder (#32064)


Signed-off-by: default avatarMatthew Bonanni <mbonanni@redhat.com>
parent 1f3a2c29
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer."""
from typing import cast from typing import TYPE_CHECKING
import torch import torch
import torch.nn as nn import torch.nn as nn
import vllm.envs as envs import vllm.envs as envs
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
from vllm.attention.utils.kv_transfer_utils import maybe_transfer_kv_layer
from vllm.config import CacheConfig, get_current_vllm_config from vllm.config import CacheConfig, get_current_vllm_config
from vllm.config.vllm import VllmConfig from vllm.config.vllm import VllmConfig
from vllm.forward_context import ForwardContext, get_forward_context from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.attention.kv_transfer_utils import (
maybe_transfer_kv_layer,
)
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
UnquantizedLinearMethod, UnquantizedLinearMethod,
) )
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
...@@ -33,20 +32,54 @@ from vllm.utils.torch_utils import ( ...@@ -33,20 +32,54 @@ from vllm.utils.torch_utils import (
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
AttentionType, AttentionType,
MLAAttentionImpl,
) )
from vllm.v1.attention.backends.registry import AttentionBackendEnum from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.selector import get_attn_backend from vllm.v1.attention.selector import get_attn_backend
from vllm.v1.kv_cache_interface import ( from vllm.v1.kv_cache_interface import (
FullAttentionSpec, FullAttentionSpec,
KVCacheSpec, KVCacheSpec,
MLAAttentionSpec,
SlidingWindowSpec, SlidingWindowSpec,
) )
if TYPE_CHECKING:
from vllm.model_executor.layers.attention import MLAAttention
logger = init_logger(__name__) logger = init_logger(__name__)
def validate_kv_sharing_target(
current_layer_name, target_layer_name, static_forward_context
):
error_msg = (
f"Specified KV sharing target layer for {current_layer_name} "
f"is not valid: target layer {target_layer_name} "
)
if current_layer_name == target_layer_name:
raise ValueError(error_msg + "cannot be the same as the current layer.")
if target_layer_name not in static_forward_context:
from vllm.model_executor.models.utils import extract_layer_index
# If target layer name is not in the static fwd context, it means either
# a) the target layer does not come BEFORE the current layer, or
# b) the target layer is not an Attention layer that exists in the model
current_layer_idx = extract_layer_index(current_layer_name)
target_layer_idx = extract_layer_index(target_layer_name)
if current_layer_idx <= target_layer_idx:
raise ValueError(error_msg + "must come before the current layer.")
else:
raise ValueError(error_msg + "is not a valid Attention layer in the model.")
# Currently KV sharing is only supported between layers of the same type
target_layer_attn_type = static_forward_context[target_layer_name].attn_type
expected = static_forward_context[current_layer_name].attn_type
if target_layer_attn_type != expected:
raise ValueError(
error_msg + f"must be the same type as the current layer ({expected})."
)
def should_load_quant_weights(quant_method: QuantizeMethodBase | None) -> bool: def should_load_quant_weights(quant_method: QuantizeMethodBase | None) -> bool:
"""Returns whether the quantization method should load quantized weights.""" """Returns whether the quantization method should load quantized weights."""
return quant_method is not None and not isinstance( return quant_method is not None and not isinstance(
...@@ -493,236 +526,6 @@ class Attention(nn.Module, AttentionLayerBase): ...@@ -493,236 +526,6 @@ class Attention(nn.Module, AttentionLayerBase):
) )
class MLAAttention(nn.Module, AttentionLayerBase):
"""Multi-Head Latent Attention layer.
This class takes query, and compressed key/value tensors as input.
The class does the following:
1. Store the input key and value tensors in the KV cache.
2. Perform (multi-head/multi-query/grouped-query) attention.
3. Return the output tensor.
"""
def __init__(
self,
num_heads: int,
scale: float,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
v_head_dim: int,
q_lora_rank: int | None,
kv_lora_rank: int,
kv_b_proj: ColumnParallelLinear,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
use_sparse: bool = False,
indexer: object | None = None,
**extra_impl_args,
):
super().__init__()
self.num_heads = num_heads
self.scale = scale
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_rope_head_dim = qk_rope_head_dim
self.v_head_dim = v_head_dim
self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
self.head_size = kv_lora_rank + qk_rope_head_dim
self.layer_name = prefix
if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
calculate_kv_scales = cache_config.calculate_kv_scales
else:
kv_cache_dtype = "auto"
block_size = 16
calculate_kv_scales = False
self.quant_config = quant_config
# Initialize KV cache quantization attributes
self.kv_cache_dtype = kv_cache_dtype
self.calculate_kv_scales = calculate_kv_scales
_init_kv_cache_quant(self, quant_config, prefix)
dtype = torch.get_default_dtype()
self.attn_backend = get_attn_backend(
self.head_size,
dtype,
kv_cache_dtype,
block_size,
use_mla=True,
use_sparse=use_sparse,
)
if (
cache_config is not None
and cache_config.enable_prefix_caching
and vllm_is_batch_invariant()
and (
self.attn_backend.get_name() == "TRITON_MLA"
or self.attn_backend.get_name() == "FLASHINFER"
)
):
logger.warning_once(
"Disabling prefix caching for TRITON_MLA / FLASHINFER "
"with batch invariance, as it is not yet supported.",
scope="local",
)
cache_config.enable_prefix_caching = False
impl_cls = cast(type[MLAAttentionImpl], self.attn_backend.get_impl_cls())
self.impl = impl_cls(
num_heads=self.num_heads,
head_size=self.head_size,
scale=self.scale,
num_kv_heads=1,
alibi_slopes=None,
sliding_window=None,
kv_cache_dtype=self.kv_cache_dtype,
logits_soft_cap=None,
attn_type=AttentionType.DECODER,
kv_sharing_target_layer_name=None,
# MLA Args
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.kv_lora_rank,
qk_nope_head_dim=self.qk_nope_head_dim,
qk_rope_head_dim=self.qk_rope_head_dim,
qk_head_dim=self.qk_nope_head_dim + self.qk_rope_head_dim,
v_head_dim=self.v_head_dim,
kv_b_proj=kv_b_proj,
indexer=indexer,
**extra_impl_args,
)
self.use_direct_call = not current_platform.opaque_attention_op()
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
self.kv_cache = [
torch.tensor([])
for _ in range(
get_current_vllm_config().parallel_config.pipeline_parallel_size
)
]
self.use_sparse = use_sparse
# Initialize q/k/v range constants.
self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)
def forward(
self,
q: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
output_shape: torch.Size | None = None,
) -> torch.Tensor:
if self.calculate_kv_scales:
torch.ops.vllm.maybe_calc_kv_scales(q, kv_c_normed, k_pe, self.layer_name)
if self.use_direct_call:
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[self.layer_name]
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
if self.attn_backend.accept_output_buffer:
output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
self.impl.forward(
self,
q,
kv_c_normed,
k_pe,
self_kv_cache,
attn_metadata,
output=output,
)
return output
else:
return self.impl.forward(
self, q, kv_c_normed, k_pe, self_kv_cache, attn_metadata
)
else:
if self.attn_backend.accept_output_buffer:
output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
torch.ops.vllm.unified_mla_attention_with_output(
q,
kv_c_normed,
k_pe,
output,
self.layer_name,
)
return output
else:
return torch.ops.vllm.unified_mla_attention(
q,
kv_c_normed,
k_pe,
self.layer_name,
)
def process_weights_after_loading(self, act_dtype: torch.dtype):
if hasattr(self.impl, "process_weights_after_loading"):
self.impl.process_weights_after_loading(act_dtype)
# If we should not load quant weights, we initialize the scales to 1.0
# as the default value. See [Note: Register q/k/v/prob scales in state dict]
# for more details.
quant_method = (
self.quant_config.get_quant_method(self, prefix=self.layer_name)
if self.quant_config
else None
)
if not should_load_quant_weights(quant_method):
set_default_quant_scales(self, register_buffer=False)
def calc_kv_scales(
self, q: torch.Tensor, kv_c_normed: torch.Tensor, k_pe: torch.Tensor
) -> None:
"""Optional scale calculation for MLA inputs.
Mirrors Attention.calc_kv_scales. Not all MLA backends require this
"""
# Use safe defaults if ranges are not present
q_range = getattr(self, "q_range", torch.tensor(1.0))
k_range = getattr(self, "k_range", torch.tensor(1.0))
v_range = getattr(self, "v_range", torch.tensor(1.0))
self._q_scale.copy_(torch.abs(q).max() / q_range)
# kv_c_normed is the compressed KV representation; use it for k/v
kv_abs_max = torch.abs(kv_c_normed).max()
self._k_scale.copy_(kv_abs_max / k_range)
self._v_scale.copy_(kv_abs_max / v_range)
self._q_scale_float = self._q_scale.item()
self._k_scale_float = self._k_scale.item()
self._v_scale_float = self._v_scale.item()
self.calculate_kv_scales = False
def get_attn_backend(self) -> type[AttentionBackend]:
return self.attn_backend
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
kv_cache_dtype = kv_cache_dtype_str_to_dtype(
self.kv_cache_dtype, vllm_config.model_config
)
return MLAAttentionSpec(
block_size=vllm_config.cache_config.block_size,
num_kv_heads=1,
head_size=self.head_size,
dtype=kv_cache_dtype,
cache_dtype_str=vllm_config.cache_config.cache_dtype,
)
def maybe_calc_kv_scales( def maybe_calc_kv_scales(
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
...@@ -759,7 +562,7 @@ direct_register_custom_op( ...@@ -759,7 +562,7 @@ direct_register_custom_op(
def get_attention_context( def get_attention_context(
layer_name: str, layer_name: str,
) -> tuple[dict | object | None, Attention | MLAAttention, torch.Tensor]: ) -> tuple[dict | object | None, "Attention | MLAAttention", torch.Tensor]:
"""Extract attention context for a given layer. """Extract attention context for a given layer.
This helper function extracts the attention metadata, attention layer This helper function extracts the attention metadata, attention layer
...@@ -782,7 +585,7 @@ def get_attention_context( ...@@ -782,7 +585,7 @@ def get_attention_context(
attn_metadata = forward_context.attn_metadata attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict): if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[layer_name] attn_metadata = attn_metadata[layer_name]
attn_layer: Attention | MLAAttention = forward_context.no_compile_layers[layer_name] attn_layer = forward_context.no_compile_layers[layer_name]
kv_cache = attn_layer.kv_cache[forward_context.virtual_engine] kv_cache = attn_layer.kv_cache[forward_context.virtual_engine]
return attn_metadata, attn_layer, kv_cache return attn_metadata, attn_layer, kv_cache
...@@ -914,79 +717,3 @@ direct_register_custom_op( ...@@ -914,79 +717,3 @@ direct_register_custom_op(
mutates_args=["output", "output_block_scale"], mutates_args=["output", "output_block_scale"],
fake_impl=unified_attention_with_output_fake, fake_impl=unified_attention_with_output_fake,
) )
@maybe_transfer_kv_layer
def unified_mla_attention(
q: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
layer_name: str,
) -> torch.Tensor:
attn_metadata, self, kv_cache = get_attention_context(layer_name)
output = self.impl.forward(self, q, kv_c_normed, k_pe, kv_cache, attn_metadata)
return output
def unified_mla_attention_fake(
q: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
layer_name: str,
) -> torch.Tensor:
return torch.empty_like(q).contiguous()
direct_register_custom_op(
op_name="unified_mla_attention",
op_func=unified_mla_attention,
mutates_args=[],
fake_impl=unified_mla_attention_fake,
dispatch_key=current_platform.dispatch_key,
)
@maybe_transfer_kv_layer
def unified_mla_attention_with_output(
q: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
output: torch.Tensor,
layer_name: str,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> None:
attn_metadata, self, kv_cache = get_attention_context(layer_name)
self.impl.forward(
self,
q,
kv_c_normed,
k_pe,
kv_cache,
attn_metadata,
output=output,
output_scale=output_scale,
output_block_scale=output_block_scale,
)
def unified_mla_attention_with_output_fake(
q: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
output: torch.Tensor,
layer_name: str,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> None:
return
direct_register_custom_op(
op_name="unified_mla_attention_with_output",
op_func=unified_mla_attention_with_output,
mutates_args=["output", "output_block_scale"],
fake_impl=unified_mla_attention_with_output_fake,
dispatch_key=current_platform.dispatch_key,
)
...@@ -4,9 +4,9 @@ import functools ...@@ -4,9 +4,9 @@ import functools
import torch import torch
from vllm.attention.layer import Attention
from vllm.config import CacheConfig from vllm.config import CacheConfig
from vllm.config.vllm import VllmConfig from vllm.config.vllm import VllmConfig
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
......
...@@ -6,9 +6,9 @@ from copy import copy ...@@ -6,9 +6,9 @@ from copy import copy
import numpy as np import numpy as np
import torch import torch
from vllm.attention.layer import Attention
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
......
...@@ -5,9 +5,9 @@ from copy import copy ...@@ -5,9 +5,9 @@ from copy import copy
import torch import torch
from vllm.attention.layer import Attention
from vllm.config import CacheConfig from vllm.config import CacheConfig
from vllm.config.vllm import VllmConfig from vllm.config.vllm import VllmConfig
from vllm.model_executor.layers.attention import Attention
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
AttentionMetadata, AttentionMetadata,
......
...@@ -19,7 +19,7 @@ def maybe_transfer_kv_layer(func: Callable) -> Callable: ...@@ -19,7 +19,7 @@ def maybe_transfer_kv_layer(func: Callable) -> Callable:
On exit: saves the KV layer to the connector. On exit: saves the KV layer to the connector.
""" """
# Import at runtime to avoid circular dependency # Import at runtime to avoid circular dependency
from vllm.attention.layer import get_attention_context from vllm.model_executor.layers.attention.attention import get_attention_context
# Inspect the signature ONCE when the decorator is applied. # Inspect the signature ONCE when the decorator is applied.
sig = inspect.signature(func) sig = inspect.signature(func)
......
...@@ -191,24 +191,38 @@ import functools ...@@ -191,24 +191,38 @@ import functools
from abc import abstractmethod from abc import abstractmethod
from dataclasses import dataclass, field from dataclasses import dataclass, field
from enum import Enum from enum import Enum
from typing import ClassVar, Generic, TypeVar from typing import TYPE_CHECKING, ClassVar, Generic, TypeVar, cast
if TYPE_CHECKING:
from flashinfer import BatchPrefillWithRaggedKVCacheWrapper
import torch import torch
import torch.nn as nn
from tqdm import tqdm from tqdm import tqdm
import vllm.envs as envs
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm import envs
from vllm._aiter_ops import rocm_aiter_ops from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import ModelConfig, VllmConfig, get_current_vllm_config from vllm.config import CacheConfig, ModelConfig, VllmConfig, get_current_vllm_config
from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.batch_invariant import ( from vllm.model_executor.layers.attention.attention import (
vllm_is_batch_invariant, _init_kv_cache_quant,
get_attention_context,
set_default_quant_scales,
should_load_quant_weights,
) )
from vllm.model_executor.layers.attention.kv_transfer_utils import (
maybe_transfer_kv_layer,
)
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
ColumnParallelLinear, ColumnParallelLinear,
) )
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape, GroupShape,
...@@ -217,11 +231,16 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ...@@ -217,11 +231,16 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_nvidia_artifactory from vllm.utils.flashinfer import has_nvidia_artifactory
from vllm.utils.math_utils import cdiv, round_down from vllm.utils.math_utils import cdiv, round_down
from vllm.utils.torch_utils import (
direct_register_custom_op,
kv_cache_dtype_str_to_dtype,
)
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
AttentionLayer, AttentionLayer,
AttentionMetadata, AttentionMetadata,
AttentionMetadataBuilder, AttentionMetadataBuilder,
AttentionType,
CommonAttentionMetadata, CommonAttentionMetadata,
MLAAttentionImpl, MLAAttentionImpl,
) )
...@@ -234,7 +253,320 @@ from vllm.v1.attention.backends.utils import ( ...@@ -234,7 +253,320 @@ from vllm.v1.attention.backends.utils import (
) )
from vllm.v1.attention.ops.common import cp_lse_ag_out_rs from vllm.v1.attention.ops.common import cp_lse_ag_out_rs
from vllm.v1.attention.ops.merge_attn_states import merge_attn_states from vllm.v1.attention.ops.merge_attn_states import merge_attn_states
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.attention.selector import get_attn_backend
from vllm.v1.kv_cache_interface import (
AttentionSpec,
KVCacheSpec,
MLAAttentionSpec,
)
logger = init_logger(__name__)
class MLAAttention(nn.Module, AttentionLayerBase):
"""Multi-Head Latent Attention layer.
This class takes query, and compressed key/value tensors as input.
The class does the following:
1. Store the input key and value tensors in the KV cache.
2. Perform (multi-head/multi-query/grouped-query) attention.
3. Return the output tensor.
"""
def __init__(
self,
num_heads: int,
scale: float,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
v_head_dim: int,
q_lora_rank: int | None,
kv_lora_rank: int,
kv_b_proj: ColumnParallelLinear,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
use_sparse: bool = False,
indexer: object | None = None,
**extra_impl_args,
):
super().__init__()
self.num_heads = num_heads
self.scale = scale
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_rope_head_dim = qk_rope_head_dim
self.v_head_dim = v_head_dim
self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
self.head_size = kv_lora_rank + qk_rope_head_dim
self.layer_name = prefix
if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
calculate_kv_scales = cache_config.calculate_kv_scales
else:
kv_cache_dtype = "auto"
block_size = 16
calculate_kv_scales = False
self.quant_config = quant_config
# Initialize KV cache quantization attributes
self.kv_cache_dtype = kv_cache_dtype
self.calculate_kv_scales = calculate_kv_scales
_init_kv_cache_quant(self, quant_config, prefix)
dtype = torch.get_default_dtype()
self.attn_backend = get_attn_backend(
self.head_size,
dtype,
kv_cache_dtype,
block_size,
use_mla=True,
use_sparse=use_sparse,
)
if (
cache_config is not None
and cache_config.enable_prefix_caching
and vllm_is_batch_invariant()
and (
self.attn_backend.get_name() == "TRITON_MLA"
or self.attn_backend.get_name() == "FLASHINFER"
)
):
logger.warning_once(
"Disabling prefix caching for TRITON_MLA / FLASHINFER "
"with batch invariance, as it is not yet supported.",
scope="local",
)
cache_config.enable_prefix_caching = False
impl_cls = cast(type[MLAAttentionImpl], self.attn_backend.get_impl_cls())
self.impl = impl_cls(
num_heads=self.num_heads,
head_size=self.head_size,
scale=self.scale,
num_kv_heads=1,
alibi_slopes=None,
sliding_window=None,
kv_cache_dtype=self.kv_cache_dtype,
logits_soft_cap=None,
attn_type=AttentionType.DECODER,
kv_sharing_target_layer_name=None,
# MLA Args
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.kv_lora_rank,
qk_nope_head_dim=self.qk_nope_head_dim,
qk_rope_head_dim=self.qk_rope_head_dim,
qk_head_dim=self.qk_nope_head_dim + self.qk_rope_head_dim,
v_head_dim=self.v_head_dim,
kv_b_proj=kv_b_proj,
indexer=indexer,
**extra_impl_args,
)
self.use_direct_call = not current_platform.opaque_attention_op()
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
self.kv_cache = [
torch.tensor([])
for _ in range(
get_current_vllm_config().parallel_config.pipeline_parallel_size
)
]
self.use_sparse = use_sparse
# Initialize q/k/v range constants.
self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)
def forward(
self,
q: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
output_shape: torch.Size | None = None,
) -> torch.Tensor:
if self.calculate_kv_scales:
torch.ops.vllm.maybe_calc_kv_scales(q, kv_c_normed, k_pe, self.layer_name)
if self.use_direct_call:
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[self.layer_name]
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
if self.attn_backend.accept_output_buffer:
output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
self.impl.forward(
self,
q,
kv_c_normed,
k_pe,
self_kv_cache,
attn_metadata,
output=output,
)
return output
else:
return self.impl.forward(
self, q, kv_c_normed, k_pe, self_kv_cache, attn_metadata
)
else:
if self.attn_backend.accept_output_buffer:
output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
torch.ops.vllm.unified_mla_attention_with_output(
q,
kv_c_normed,
k_pe,
output,
self.layer_name,
)
return output
else:
return torch.ops.vllm.unified_mla_attention(
q,
kv_c_normed,
k_pe,
self.layer_name,
)
def process_weights_after_loading(self, act_dtype: torch.dtype):
if hasattr(self.impl, "process_weights_after_loading"):
self.impl.process_weights_after_loading(act_dtype)
# If we should not load quant weights, we initialize the scales to 1.0
# as the default value. See [Note: Register q/k/v/prob scales in state dict]
# for more details.
quant_method = (
self.quant_config.get_quant_method(self, prefix=self.layer_name)
if self.quant_config
else None
)
if not should_load_quant_weights(quant_method):
set_default_quant_scales(self, register_buffer=False)
def calc_kv_scales(
self, q: torch.Tensor, kv_c_normed: torch.Tensor, k_pe: torch.Tensor
) -> None:
"""Optional scale calculation for MLA inputs.
Mirrors Attention.calc_kv_scales. Not all MLA backends require this
"""
# Use safe defaults if ranges are not present
q_range = getattr(self, "q_range", torch.tensor(1.0))
k_range = getattr(self, "k_range", torch.tensor(1.0))
v_range = getattr(self, "v_range", torch.tensor(1.0))
self._q_scale.copy_(torch.abs(q).max() / q_range)
# kv_c_normed is the compressed KV representation; use it for k/v
kv_abs_max = torch.abs(kv_c_normed).max()
self._k_scale.copy_(kv_abs_max / k_range)
self._v_scale.copy_(kv_abs_max / v_range)
self._q_scale_float = self._q_scale.item()
self._k_scale_float = self._k_scale.item()
self._v_scale_float = self._v_scale.item()
self.calculate_kv_scales = False
def get_attn_backend(self) -> type[AttentionBackend]:
return self.attn_backend
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
kv_cache_dtype = kv_cache_dtype_str_to_dtype(
self.kv_cache_dtype, vllm_config.model_config
)
return MLAAttentionSpec(
block_size=vllm_config.cache_config.block_size,
num_kv_heads=1,
head_size=self.head_size,
dtype=kv_cache_dtype,
cache_dtype_str=vllm_config.cache_config.cache_dtype,
)
@maybe_transfer_kv_layer
def unified_mla_attention(
q: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
layer_name: str,
) -> torch.Tensor:
attn_metadata, self, kv_cache = get_attention_context(layer_name)
output = self.impl.forward(self, q, kv_c_normed, k_pe, kv_cache, attn_metadata)
return output
def unified_mla_attention_fake(
q: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
layer_name: str,
) -> torch.Tensor:
return torch.empty_like(q).contiguous()
direct_register_custom_op(
op_name="unified_mla_attention",
op_func=unified_mla_attention,
mutates_args=[],
fake_impl=unified_mla_attention_fake,
dispatch_key=current_platform.dispatch_key,
)
@maybe_transfer_kv_layer
def unified_mla_attention_with_output(
q: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
output: torch.Tensor,
layer_name: str,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> None:
attn_metadata, self, kv_cache = get_attention_context(layer_name)
self.impl.forward(
self,
q,
kv_c_normed,
k_pe,
kv_cache,
attn_metadata,
output=output,
output_scale=output_scale,
output_block_scale=output_block_scale,
)
def unified_mla_attention_with_output_fake(
q: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
output: torch.Tensor,
layer_name: str,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> None:
return
direct_register_custom_op(
op_name="unified_mla_attention_with_output",
op_func=unified_mla_attention_with_output,
mutates_args=["output", "output_block_scale"],
fake_impl=unified_mla_attention_with_output_fake,
dispatch_key=current_platform.dispatch_key,
)
class QueryLenSupport(Enum): class QueryLenSupport(Enum):
...@@ -266,15 +598,12 @@ except ImportError: ...@@ -266,15 +598,12 @@ except ImportError:
from flash_attn import flash_attn_varlen_func # type: ignore[no-redef] from flash_attn import flash_attn_varlen_func # type: ignore[no-redef]
is_vllm_fa = False is_vllm_fa = False
try:
from flashinfer import BatchPrefillWithRaggedKVCacheWrapper
from flashinfer.prefill import cudnn_batch_prefill_with_kv_cache # noqa: F401
flashinfer_available = True @functools.cache
except ImportError: def flashinfer_available() -> bool:
BatchPrefillWithRaggedKVCacheWrapper = object import importlib.util
flashinfer_available = False return importlib.util.find_spec("flashinfer") is not None
def dynamic_per_batched_tensor_quant( def dynamic_per_batched_tensor_quant(
...@@ -398,8 +727,8 @@ class MLACommonPrefillMetadata: ...@@ -398,8 +727,8 @@ class MLACommonPrefillMetadata:
@dataclass @dataclass
class FlashInferPrefillMetadata(MLACommonPrefillMetadata): class FlashInferPrefillMetadata(MLACommonPrefillMetadata):
prefill_main: BatchPrefillWithRaggedKVCacheWrapper | None = None prefill_main: "BatchPrefillWithRaggedKVCacheWrapper | None" = None
prefill_chunks: list[BatchPrefillWithRaggedKVCacheWrapper] = field( prefill_chunks: "list[BatchPrefillWithRaggedKVCacheWrapper]" = field(
default_factory=list default_factory=list
) )
...@@ -495,7 +824,7 @@ def use_flashinfer_prefill() -> bool: ...@@ -495,7 +824,7 @@ def use_flashinfer_prefill() -> bool:
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
if not ( if not (
not vllm_config.attention_config.disable_flashinfer_prefill not vllm_config.attention_config.disable_flashinfer_prefill
and flashinfer_available and flashinfer_available()
and not vllm_config.attention_config.use_cudnn_prefill and not vllm_config.attention_config.use_cudnn_prefill
and current_platform.is_device_capability_family(100) and current_platform.is_device_capability_family(100)
): ):
...@@ -509,7 +838,7 @@ def use_cudnn_prefill() -> bool: ...@@ -509,7 +838,7 @@ def use_cudnn_prefill() -> bool:
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
return ( return (
flashinfer_available flashinfer_available()
and vllm_config.attention_config.use_cudnn_prefill and vllm_config.attention_config.use_cudnn_prefill
and current_platform.is_device_capability_family(100) and current_platform.is_device_capability_family(100)
and has_nvidia_artifactory() and has_nvidia_artifactory()
...@@ -731,6 +1060,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -731,6 +1060,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
has_context = True has_context = True
if self._fi_prefill_main is None: if self._fi_prefill_main is None:
from flashinfer import BatchPrefillWithRaggedKVCacheWrapper
self._fi_prefill_main = BatchPrefillWithRaggedKVCacheWrapper( self._fi_prefill_main = BatchPrefillWithRaggedKVCacheWrapper(
self._workspace_buffer, "NHD", backend="cutlass" self._workspace_buffer, "NHD", backend="cutlass"
) )
...@@ -739,6 +1070,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -739,6 +1070,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
num_chunks = chunked_context.cu_seq_lens.shape[0] num_chunks = chunked_context.cu_seq_lens.shape[0]
# Allocate more prefill chunk wrappers if needed # Allocate more prefill chunk wrappers if needed
if len(self._fi_prefill_chunks) < num_chunks: if len(self._fi_prefill_chunks) < num_chunks:
from flashinfer import BatchPrefillWithRaggedKVCacheWrapper
for _ in range(len(self._fi_prefill_chunks), num_chunks): for _ in range(len(self._fi_prefill_chunks), num_chunks):
self._fi_prefill_chunks.append( self._fi_prefill_chunks.append(
BatchPrefillWithRaggedKVCacheWrapper( BatchPrefillWithRaggedKVCacheWrapper(
...@@ -1513,6 +1846,8 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ...@@ -1513,6 +1846,8 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
): ):
assert isinstance(prefill, CudnnPrefillMetadata) assert isinstance(prefill, CudnnPrefillMetadata)
assert prefill.query_seq_lens is not None assert prefill.query_seq_lens is not None
from flashinfer.prefill import cudnn_batch_prefill_with_kv_cache
output, lse = cudnn_batch_prefill_with_kv_cache( output, lse = cudnn_batch_prefill_with_kv_cache(
q=q, q=q,
k_cache=k, k_cache=k,
...@@ -1572,6 +1907,8 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ...@@ -1572,6 +1907,8 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
assert prefill.chunked_context is not None assert prefill.chunked_context is not None
assert prefill.chunked_context.seq_lens[chunk_idx] is not None assert prefill.chunked_context.seq_lens[chunk_idx] is not None
assert prefill.query_seq_lens is not None assert prefill.query_seq_lens is not None
from flashinfer.prefill import cudnn_batch_prefill_with_kv_cache
return cudnn_batch_prefill_with_kv_cache( return cudnn_batch_prefill_with_kv_cache(
q=q, q=q,
k_cache=k, k_cache=k,
......
...@@ -4,11 +4,11 @@ import functools ...@@ -4,11 +4,11 @@ import functools
import torch import torch
from vllm.attention.layer import Attention
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.forward_context import ForwardContext, get_forward_context from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.attention import Attention
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import direct_register_custom_op from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
......
...@@ -4,9 +4,9 @@ from dataclasses import dataclass ...@@ -4,9 +4,9 @@ from dataclasses import dataclass
import torch import torch
from vllm.attention.layer import MLAAttention
from vllm.config import CacheConfig from vllm.config import CacheConfig
from vllm.model_executor.custom_op import PluggableLayer from vllm.model_executor.custom_op import PluggableLayer
from vllm.model_executor.layers.attention import MLAAttention
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
......
...@@ -19,12 +19,12 @@ from compressed_tensors.quantization import ( ...@@ -19,12 +19,12 @@ from compressed_tensors.quantization import (
from compressed_tensors.transform import TransformConfig from compressed_tensors.transform import TransformConfig
import vllm.envs as envs import vllm.envs as envs
from vllm.attention.layer import Attention
from vllm.distributed import ( from vllm.distributed import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
LinearBase, LinearBase,
......
...@@ -11,9 +11,9 @@ import vllm.envs as envs ...@@ -11,9 +11,9 @@ import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops from vllm._aiter_ops import rocm_aiter_ops
from vllm.attention.layer import Attention
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.batch_invariant import ( from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant, vllm_is_batch_invariant,
) )
......
...@@ -11,8 +11,8 @@ from torch.nn.parameter import Parameter ...@@ -11,8 +11,8 @@ from torch.nn.parameter import Parameter
import vllm.envs as envs import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.attention.layer import Attention
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
FusedMoEQuantConfig, FusedMoEQuantConfig,
......
...@@ -7,9 +7,9 @@ import torch ...@@ -7,9 +7,9 @@ import torch
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from vllm import envs from vllm import envs
from vllm.attention.layer import Attention
from vllm.config import get_current_vllm_config from vllm.config import get_current_vllm_config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import ( from vllm.model_executor.layers.fused_moe import (
FusedMoE, FusedMoE,
FusedMoEConfig, FusedMoEConfig,
......
...@@ -8,8 +8,8 @@ import regex as re ...@@ -8,8 +8,8 @@ import regex as re
import torch import torch
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from vllm.attention.layer import Attention
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
LinearBase, LinearBase,
LinearMethodBase, LinearMethodBase,
......
...@@ -7,8 +7,8 @@ import torch ...@@ -7,8 +7,8 @@ import torch
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.attention.layer import Attention
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
......
...@@ -6,8 +6,8 @@ from typing import TYPE_CHECKING, Any, Optional, cast ...@@ -6,8 +6,8 @@ from typing import TYPE_CHECKING, Any, Optional, cast
import torch import torch
from vllm.attention.layer import Attention
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
LinearBase, LinearBase,
......
...@@ -11,9 +11,9 @@ import torch ...@@ -11,9 +11,9 @@ import torch
from torch import nn from torch import nn
from typing_extensions import assert_never from typing_extensions import assert_never
from vllm.attention.layer import Attention, MLAAttention
from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention, MLAAttention
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizationConfig,
QuantizeMethodBase, QuantizeMethodBase,
......
...@@ -9,7 +9,6 @@ from itertools import islice ...@@ -9,7 +9,6 @@ from itertools import islice
import torch import torch
from torch import nn from torch import nn
from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
from vllm.distributed import ( from vllm.distributed import (
...@@ -18,6 +17,7 @@ from vllm.distributed import ( ...@@ -18,6 +17,7 @@ from vllm.distributed import (
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
......
...@@ -11,7 +11,7 @@ import torch.nn as nn ...@@ -11,7 +11,7 @@ import torch.nn as nn
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.utils import divide from vllm.distributed.utils import divide
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention.mm_encoder_attention import MMEncoderAttention from vllm.model_executor.layers.attention import MMEncoderAttention
from vllm.model_executor.layers.conv import Conv2dLayer from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
......
...@@ -32,12 +32,12 @@ import torch ...@@ -32,12 +32,12 @@ import torch
from torch import nn from torch import nn
from transformers import ApertusConfig from transformers import ApertusConfig
from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import XIELU from vllm.model_executor.layers.activation import XIELU
from vllm.model_executor.layers.attention.encoder_only_attention import ( from vllm.model_executor.layers.attention import (
Attention,
EncoderOnlyAttention, EncoderOnlyAttention,
) )
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
......
...@@ -8,7 +8,6 @@ from itertools import islice ...@@ -8,7 +8,6 @@ from itertools import islice
import torch import torch
from torch import nn from torch import nn
from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import ( from vllm.distributed import (
...@@ -19,6 +18,7 @@ from vllm.distributed import ( ...@@ -19,6 +18,7 @@ from vllm.distributed import (
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment