Unverified commit 4e68cc9b, authored by Zhiyuan Li, committed by GitHub
Browse files

[Model] Introduce Kimi Linear to vLLM (#27809)


Signed-off-by: lizhiyuan <lizhiyuan@moonshot.cn>
Signed-off-by: Zhiyuan Li <uniartisan2017@gmail.com>
parent 1994de99
......@@ -382,6 +382,7 @@ th {
| `InternLM3ForCausalLM` | InternLM3 | `internlm/internlm3-8b-instruct`, etc. | ✅︎ | ✅︎ |
| `JAISLMHeadModel` | Jais | `inceptionai/jais-13b`, `inceptionai/jais-13b-chat`, `inceptionai/jais-30b-v3`, `inceptionai/jais-30b-chat-v3`, etc. | | ✅︎ |
| `JambaForCausalLM` | Jamba | `ai21labs/AI21-Jamba-1.5-Large`, `ai21labs/AI21-Jamba-1.5-Mini`, `ai21labs/Jamba-v0.1`, etc. | ✅︎ | ✅︎ |
| `KimiLinearForCausalLM` | Kimi-Linear-48B-A3B-Base, Kimi-Linear-48B-A3B-Instruct | `moonshotai/Kimi-Linear-48B-A3B-Base`, `moonshotai/Kimi-Linear-48B-A3B-Instruct` | | ✅︎ |
| `Lfm2ForCausalLM` | LFM2 | `LiquidAI/LFM2-1.2B`, `LiquidAI/LFM2-700M`, `LiquidAI/LFM2-350M`, etc. | ✅︎ | ✅︎ |
| `Lfm2MoeForCausalLM` | LFM2MoE | `LiquidAI/LFM2-8B-A1B-preview`, etc. | ✅︎ | ✅︎ |
| `LlamaForCausalLM` | Llama 3.1, Llama 3, Llama 2, LLaMA, Yi | `meta-llama/Meta-Llama-3.1-405B-Instruct`, `meta-llama/Meta-Llama-3.1-70B`, `meta-llama/Meta-Llama-3-70B-Instruct`, `meta-llama/Llama-2-70b-hf`, `01-ai/Yi-34B`, etc. | ✅︎ | ✅︎ |
......
......@@ -296,6 +296,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"random": "ai21labs/Jamba-tiny-random",
},
),
"KimiLinearForCausalLM": _HfExamplesInfo(
"moonshotai/Kimi-Linear-48B-A3B-Instruct", trust_remote_code=True
),
"Lfm2ForCausalLM": _HfExamplesInfo("LiquidAI/LFM2-1.2B"),
"Lfm2MoeForCausalLM": _HfExamplesInfo(
"LiquidAI/LFM2-8B-A1B", min_transformers_version="4.58"
......
......@@ -453,6 +453,7 @@ class CompilationConfig:
"vllm::linear_attention",
"vllm::plamo2_mamba_mixer",
"vllm::gdn_attention",
"vllm::kda_attention",
"vllm::sparse_attn_indexer",
]
......
......@@ -1236,6 +1236,7 @@ class ModelConfig:
"deepseek_v32",
"deepseek_mtp",
"kimi_k2",
"kimi_linear",
"longcat_flash",
):
return self.hf_text_config.kv_lora_rank is not None
......
......@@ -1304,7 +1304,7 @@ def kda_gate_fwd_kernel(
tl.store(y_ptr, b_y.to(y.dtype.element_ty), boundary_check=(0, 1))
def kda_gate_fwd(
def fused_kda_gate(
g: torch.Tensor,
A: torch.Tensor,
head_k_dim: int,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from einops import rearrange
from torch import nn
from vllm.attention import AttentionBackend
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
from vllm.distributed import (
divide,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.model_loader.weight_utils import sharded_weight_loader
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
from .fla.ops.kda import (
FusedRMSNormGated,
chunk_kda,
fused_kda_gate,
fused_recurrent_kda,
)
from .linear import (
ColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from .mamba.abstract import MambaBase
from .mamba.mamba_utils import MambaStateDtypeCalculator, MambaStateShapeCalculator
from .mamba.ops.causal_conv1d import causal_conv1d_fn, causal_conv1d_update
from .quantization.base_config import QuantizationConfig
logger = init_logger(__name__)
def kda_attention(
    hidden_states: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
) -> None:
    """Run the layer registered under ``layer_name`` on ``hidden_states``.

    The layer object is fetched from the current forward context's
    ``no_compile_layers`` mapping and its ``_forward`` method is called with
    ``hidden_states`` and ``output``. Nothing is returned.
    """
    context: ForwardContext = get_forward_context()
    layer = context.no_compile_layers[layer_name]
    layer._forward(hidden_states=hidden_states, output=output)
def kda_attention_fake(
    hidden_states: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
) -> None:
    """No-op counterpart of ``kda_attention``: ignores every argument."""
    return None
# Register ``kda_attention`` as a custom op named "kda_attention". The op is
# declared as mutating its ``output`` argument, and ``kda_attention_fake`` is
# supplied as its fake implementation.
direct_register_custom_op(
    op_name="kda_attention",
    op_func=kda_attention,
    mutates_args=["output"],
    fake_impl=kda_attention_fake,
)
class KimiDeltaAttention(nn.Module, MambaBase):
@property
def mamba_type(self) -> str:
    """Type tag reported by this layer; always ``"linear_attention"``."""
    layer_kind = "linear_attention"
    return layer_kind
def get_attn_backend(self) -> type["AttentionBackend"]:
    """Return the ``GDNAttentionBackend`` class used by this layer.

    The backend module is imported inside the method rather than at call
    sites, mirroring the original deferred import.
    """
    from vllm.v1.attention.backends import gdn_attn

    return gdn_attn.GDNAttentionBackend
def get_state_dtype(
    self,
) -> tuple[torch.dtype, torch.dtype, torch.dtype, torch.dtype]:
    """Return the dtypes of the four cached state tensors of this layer.

    The values come from ``MambaStateDtypeCalculator.kda_state_dtype``,
    which is given the model dtype and the configured mamba cache dtype.

    Raises:
        ValueError: if ``model_config`` or ``cache_config`` is ``None``.
    """
    config_missing = self.model_config is None or self.cache_config is None
    if config_missing:
        raise ValueError("model_config and cache_config must be set")

    model_dtype = self.model_config.dtype
    cache_dtype = self.cache_config.mamba_cache_dtype
    return MambaStateDtypeCalculator.kda_state_dtype(model_dtype, cache_dtype)
def get_state_shape(
    self,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
    """Return the shapes of the four cached state tensors of this layer.

    Delegates to ``MambaStateShapeCalculator.kda_state_shape`` using this
    layer's tensor-parallel size, head count, head dimension and short
    convolution kernel size.
    """
    positional = (self.tp_size, self.num_heads, self.head_dim)
    return MambaStateShapeCalculator.kda_state_shape(
        *positional, conv_kernel_size=self.conv_size
    )
def __init__(
    self,
    layer_idx: int,
    hidden_size: int,
    quant_config: QuantizationConfig | None = None,
    cache_config: CacheConfig | None = None,
    model_config: ModelConfig | None = None,
    rms_norm_eps: float = 1e-5,
    prefix: str = "",
    **kwargs,
) -> None:
    """Build the projections, short convolutions and norms of the layer.

    Args:
        layer_idx: Index of this layer; stored on the instance.
        hidden_size: Input/output feature size of the layer.
        quant_config: Passed to the q/k/v/f/b/g/o projection layers. It is
            not passed to the three ``*_conv1d`` layers.
        cache_config: Stored for later use by ``get_state_dtype``.
        model_config: Must not be ``None``; its ``linear_attn_config``
            supplies ``head_dim``, ``num_heads`` and
            ``short_conv_kernel_size``.
        rms_norm_eps: Epsilon handed to the gated output norm.
        prefix: Name of this layer. Used to prefix sub-layer names and as
            the key under which the layer registers itself in the static
            forward context.
        **kwargs: Accepted but not used in this constructor.

    Raises:
        ValueError: if ``model_config`` is ``None`` or ``prefix`` is already
            present in the static forward context.
    """
    super().__init__()
    self.tp_size = get_tensor_model_parallel_world_size()
    self.tp_rank = get_tensor_model_parallel_rank()
    self.hidden_size = hidden_size
    self.model_config = model_config
    self.cache_config = cache_config
    if model_config is None:
        raise ValueError("model_config must be provided")
    kda_config = model_config.linear_attn_config
    self.head_dim = kda_config["head_dim"]
    self.num_heads = kda_config["num_heads"]
    self.layer_idx = layer_idx
    self.prefix = prefix
    # Heads are split evenly across tensor-parallel ranks.
    assert self.num_heads % self.tp_size == 0
    self.local_num_heads = divide(self.num_heads, self.tp_size)

    projection_size = self.head_dim * self.num_heads
    self.conv_size = kda_config["short_conv_kernel_size"]

    # Query / key / value projections.
    self.q_proj = ColumnParallelLinear(
        self.hidden_size,
        projection_size,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.q_proj",
    )
    self.k_proj = ColumnParallelLinear(
        self.hidden_size,
        projection_size,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.k_proj",
    )
    self.v_proj = ColumnParallelLinear(
        self.hidden_size,
        projection_size,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.v_proj",
    )

    # Two-stage projection (hidden_size -> head_dim -> projection_size)
    # whose output is fed to ``fused_kda_gate`` in ``_forward``.
    self.f_a_proj = ReplicatedLinear(
        self.hidden_size,
        self.head_dim,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.f_a_proj",
    )

    self.f_b_proj = ColumnParallelLinear(
        self.head_dim,
        projection_size,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.f_b_proj",
    )
    # Bias passed as ``g_bias`` to ``fused_kda_gate``; loaded sharded along
    # dimension 0.
    self.dt_bias = nn.Parameter(
        torch.empty(divide(projection_size, self.tp_size), dtype=torch.float32)
    )

    set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)})

    # Projection producing ``beta`` (one value per head) in ``_forward``.
    self.b_proj = ColumnParallelLinear(
        self.hidden_size,
        self.num_heads,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.b_proj",
    )

    # Short-convolution weights are held in float32 ColumnParallelLinear
    # layers with ``conv_size`` inputs and ``projection_size`` outputs.
    self.q_conv1d = ColumnParallelLinear(
        input_size=self.conv_size,
        output_size=projection_size,
        bias=False,
        params_dtype=torch.float32,
        prefix=f"{prefix}.q_conv1d",
    )
    self.k_conv1d = ColumnParallelLinear(
        input_size=self.conv_size,
        output_size=projection_size,
        bias=False,
        params_dtype=torch.float32,
        prefix=f"{prefix}.k_conv1d",
    )
    self.v_conv1d = ColumnParallelLinear(
        input_size=self.conv_size,
        output_size=projection_size,
        bias=False,
        params_dtype=torch.float32,
        prefix=f"{prefix}.v_conv1d",
    )
    # unsqueeze to fit conv1d weights shape into the linear weights shape.
    # Can't do this in `weight_loader` since it already exists in
    # `ColumnParallelLinear` and `set_weight_attrs`
    # doesn't allow to override it
    self.q_conv1d.weight.data = self.q_conv1d.weight.data.unsqueeze(1)
    self.k_conv1d.weight.data = self.k_conv1d.weight.data.unsqueeze(1)
    self.v_conv1d.weight.data = self.v_conv1d.weight.data.unsqueeze(1)

    # Per-head parameter passed to ``fused_kda_gate``; loaded sharded along
    # dimension 2 (the local-heads axis of this 4-D tensor).
    self.A_log = nn.Parameter(
        torch.empty(1, 1, self.local_num_heads, 1, dtype=torch.float32)
    )
    set_weight_attrs(self.A_log, {"weight_loader": sharded_weight_loader(2)})

    # Two-stage projection producing the gate given to ``o_norm``.
    self.g_a_proj = ReplicatedLinear(
        self.hidden_size,
        self.head_dim,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.g_a_proj",
    )
    self.g_b_proj = ColumnParallelLinear(
        self.head_dim,
        projection_size,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.g_b_proj",
    )
    self.o_norm = FusedRMSNormGated(
        self.head_dim, eps=rms_norm_eps, activation="sigmoid"
    )
    self.o_proj = RowParallelLinear(
        projection_size,
        self.hidden_size,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.o_proj",
    )

    # Register this layer under its prefix so it can be looked up by name
    # later; a prefix may be registered only once.
    compilation_config = get_current_vllm_config().compilation_config
    if prefix in compilation_config.static_forward_context:
        raise ValueError(f"Duplicate layer name: {prefix}")
    compilation_config.static_forward_context[prefix] = self
def forward(
    self,
    hidden_states: torch.Tensor,
    positions: torch.Tensor,
    output: torch.Tensor,
) -> None:
    """Dispatch to the ``vllm.kda_attention`` custom op for this layer.

    ``positions`` is accepted but not forwarded to the op; the layer is
    identified to the op by ``self.prefix``.
    """
    layer_name = self.prefix
    kda_op = torch.ops.vllm.kda_attention
    return kda_op(hidden_states, output, layer_name)
def _forward(
    self,
    hidden_states: torch.Tensor,
    output: torch.Tensor,
) -> None:
    """Compute the layer output for ``hidden_states`` and write it to ``output``.

    When the forward context carries no attention metadata (profile run),
    only placeholder tensors are allocated and the method returns without
    touching ``output``. Otherwise q/k/v are projected, passed through the
    short causal convolutions (updating the cached conv states), combined
    by the chunked (prefill) or recurrent (decode) KDA kernel together with
    the cached recurrent state, gated-normalised and projected into
    ``output`` in place.

    Args:
        hidden_states: Input activations for the tokens of this step.
        output: Pre-allocated tensor that receives the result via
            ``output[:] = ...``.
    """
    forward_context = get_forward_context()
    attn_metadata: AttentionMetadata = forward_context.attn_metadata

    if attn_metadata is None:
        # V1 profile run
        # Mimic the memory allocation in the real run. The tensors are
        # intentionally unused; they exist only to reserve memory.
        q = torch.empty_like(hidden_states)
        k = torch.empty_like(hidden_states)
        v = torch.empty_like(hidden_states)
        g = hidden_states.new_empty(
            hidden_states.size(0),
            self.local_num_heads,
            self.head_dim,
            dtype=torch.float32,
        )
        # Allocate on the same device as ``hidden_states`` (like ``g``);
        # a bare ``torch.empty`` would use the default device instead.
        beta = hidden_states.new_empty(
            hidden_states.size(0), self.local_num_heads, dtype=torch.float32
        )
        core_attn_out = torch.empty_like(hidden_states)
        return

    # Metadata is keyed by layer name; pick the entry for this layer.
    assert isinstance(attn_metadata, dict)
    attn_metadata = attn_metadata[self.prefix]
    assert isinstance(attn_metadata, GDNAttentionMetadata)
    has_initial_state = attn_metadata.has_initial_state
    non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc
    non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor  # noqa: E501
    constant_caches = self.kv_cache[forward_context.virtual_engine]

    (conv_state_q, conv_state_k, conv_state_v, recurrent_state) = constant_caches
    # deal with strides
    conv_state_q = conv_state_q.transpose(-1, -2)
    conv_state_k = conv_state_k.transpose(-1, -2)
    conv_state_v = conv_state_v.transpose(-1, -2)

    q_proj_states = self.q_proj(hidden_states)[0]
    k_proj_states = self.k_proj(hidden_states)[0]
    v_proj_states = self.v_proj(hidden_states)[0]

    # Drop the singleton middle dimension of the stored conv weights,
    # leaving a 2-D (size(0), size(2)) view.
    q_conv_weights = self.q_conv1d.weight.view(
        self.q_conv1d.weight.size(0), self.q_conv1d.weight.size(2)
    )
    k_conv_weights = self.k_conv1d.weight.view(
        self.k_conv1d.weight.size(0), self.k_conv1d.weight.size(2)
    )
    v_conv_weights = self.v_conv1d.weight.view(
        self.v_conv1d.weight.size(0), self.v_conv1d.weight.size(2)
    )
    if attn_metadata.num_prefills > 0:
        # Prefill path: the conv kernel takes the transposed projections
        # and its result is transposed back.
        q_proj_states = q_proj_states.transpose(0, 1)
        k_proj_states = k_proj_states.transpose(0, 1)
        v_proj_states = v_proj_states.transpose(0, 1)
        q = causal_conv1d_fn(
            q_proj_states,
            q_conv_weights,
            self.q_conv1d.bias,
            activation="silu",
            conv_states=conv_state_q,
            has_initial_state=has_initial_state,
            cache_indices=non_spec_state_indices_tensor,
            query_start_loc=non_spec_query_start_loc,
            metadata=attn_metadata,
        ).transpose(0, 1)
        k = causal_conv1d_fn(
            k_proj_states,
            k_conv_weights,
            self.k_conv1d.bias,
            activation="silu",
            conv_states=conv_state_k,
            has_initial_state=has_initial_state,
            cache_indices=non_spec_state_indices_tensor,
            query_start_loc=non_spec_query_start_loc,
            metadata=attn_metadata,
        ).transpose(0, 1)
        v = causal_conv1d_fn(
            v_proj_states,
            v_conv_weights,
            self.v_conv1d.bias,
            activation="silu",
            conv_states=conv_state_v,
            has_initial_state=has_initial_state,
            cache_indices=non_spec_state_indices_tensor,
            query_start_loc=non_spec_query_start_loc,
            metadata=attn_metadata,
        ).transpose(0, 1)
    else:
        # Decode path: single-step conv update on the cached conv states.
        decode_conv_indices = non_spec_state_indices_tensor[
            : attn_metadata.num_decodes
        ]
        q = causal_conv1d_update(
            q_proj_states,
            conv_state_q,
            q_conv_weights,
            self.q_conv1d.bias,
            activation="silu",
            conv_state_indices=decode_conv_indices,
            validate_data=True,
        )
        k = causal_conv1d_update(
            k_proj_states,
            conv_state_k,
            k_conv_weights,
            self.k_conv1d.bias,
            activation="silu",
            conv_state_indices=decode_conv_indices,
            validate_data=True,
        )
        v = causal_conv1d_update(
            v_proj_states,
            conv_state_v,
            v_conv_weights,
            self.v_conv1d.bias,
            activation="silu",
            conv_state_indices=decode_conv_indices,
            validate_data=True,
        )

    # Split the flat feature axis into heads and add a leading batch axis.
    q, k, v = map(
        lambda x: rearrange(x, "n (h d) -> 1 n h d", d=self.head_dim), (q, k, v)
    )

    beta = self.b_proj(hidden_states)[0].float().sigmoid()

    g = self.f_b_proj(self.f_a_proj(hidden_states)[0])[0]
    g = fused_kda_gate(g, self.A_log, self.head_dim, g_bias=self.dt_bias)

    beta = beta.unsqueeze(0)
    g = g.unsqueeze(0)

    if attn_metadata.num_prefills > 0:
        # Zero the cached recurrent state of sequences that have no
        # initial state before gathering the initial states.
        zero_idx = non_spec_state_indices_tensor[~has_initial_state]
        recurrent_state[zero_idx] = 0
        initial_state = recurrent_state[non_spec_state_indices_tensor].contiguous()
        (
            core_attn_out_non_spec,
            last_recurrent_state,
        ) = chunk_kda(
            q=q,
            k=k,
            v=v,
            g=g,
            beta=beta,
            initial_state=initial_state,
            output_final_state=True,
            use_qk_l2norm_in_kernel=True,
            cu_seqlens=non_spec_query_start_loc,
        )
        # Init cache
        recurrent_state[non_spec_state_indices_tensor] = last_recurrent_state
    else:
        # The second return value is not needed on this path; the kernel is
        # handed the full state cache together with ``ssm_state_indices``.
        core_attn_out_non_spec, _ = fused_recurrent_kda(
            q=q,
            k=k,
            v=v,
            g=g,
            beta=beta,
            initial_state=recurrent_state,
            use_qk_l2norm_in_kernel=True,
            cu_seqlens=non_spec_query_start_loc,
            ssm_state_indices=non_spec_state_indices_tensor,
        )

    # Gated RMS norm of the attention output, then the output projection.
    g_proj_states = self.g_b_proj(self.g_a_proj(hidden_states)[0])[0]
    g = rearrange(g_proj_states, "... (h d) -> ... h d", d=self.head_dim)
    core_attn_out = self.o_norm(core_attn_out_non_spec, g)
    core_attn_out = rearrange(core_attn_out, "1 n h d -> n (h d)")

    output[:] = self.o_proj(core_attn_out)[0]
......@@ -80,6 +80,15 @@ class MambaStateDtypeCalculator:
state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
return (state_dtype, state_dtype)
@classmethod
def kda_state_dtype(
    cls,
    model_dtype: ModelDType | torch.dtype,
    mamba_cache_dtype: MambaDType,
):
    """Return the 4-tuple of state dtypes used for KDA.

    The first three entries are the dtype resolved by
    ``get_kv_cache_torch_dtype`` from the cache and model dtypes; the
    fourth entry is always ``torch.float32``.
    """
    resolved = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
    return (resolved,) * 3 + (torch.float32,)
class MambaStateShapeCalculator:
@classmethod
......@@ -182,3 +191,35 @@ class MambaStateShapeCalculator:
head_v_dim,
)
return conv_state_shape, temporal_state_shape
@classmethod
def kda_state_shape(
    cls,
    tp_world_size: int,
    num_heads: int,
    head_dim: int,
    num_k_heads: int | None = None,
    head_k_dim: int | None = None,
    conv_kernel_size: int = 4,
    num_spec: int = 0,
) -> tuple[tuple[int, int], tuple[int, int], tuple[int, int], tuple[int, int, int]]:
    """Return the per-rank shapes of the four KDA state tensors.

    The result is ``(conv, conv_k, conv_k, recurrent)`` where each conv
    shape is ``(conv_kernel_size - 1, projected_size / tp_world_size)`` and
    the recurrent shape is ``(num_heads / tp_world_size, head_dim, head_dim)``.
    ``num_k_heads`` and ``head_k_dim`` fall back to ``num_heads`` and
    ``head_dim`` when ``None``. ``num_spec`` is accepted but not used.
    """
    k_heads = num_heads if num_k_heads is None else num_k_heads
    k_dim = head_dim if head_k_dim is None else head_k_dim
    conv_width = conv_kernel_size - 1

    # Keep the three ``divide`` calls in their original order.
    local_proj = divide(num_heads * head_dim, tp_world_size)
    local_proj_k = divide(k_heads * k_dim, tp_world_size)
    local_heads = divide(num_heads, tp_world_size)

    # Conv shapes are emitted with the kernel width first.
    conv_shape = (conv_width, local_proj)
    conv_k_shape = (conv_width, local_proj_k)
    recurrent_shape = (local_heads, head_dim, head_dim)
    return (conv_shape, conv_k_shape, conv_k_shape, recurrent_shape)
......@@ -147,9 +147,10 @@ class MultiHeadLatentAttentionWrapper(CustomOp):
# Add head dim of 1 to k_pe
k_pe = k_pe.unsqueeze(1)
q[..., self.qk_nope_head_dim :], k_pe = self.rotary_emb(
positions, q[..., self.qk_nope_head_dim :], k_pe
)
if self.rotary_emb is not None:
q[..., self.qk_nope_head_dim :], k_pe = self.rotary_emb(
positions, q[..., self.qk_nope_head_dim :], k_pe
)
if self.indexer and self.is_sparse:
_topk_indices = self.indexer(hidden_states, q_c, positions, self.rotary_emb)
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from copy import deepcopy
from math import lcm
from typing import TYPE_CHECKING
import vllm.envs as envs
......@@ -8,7 +9,7 @@ from vllm.logger import init_logger
from vllm.model_executor.models import ModelRegistry
from vllm.utils.math_utils import cdiv, round_up
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec, MLAAttentionSpec
if TYPE_CHECKING:
from vllm.config import VllmConfig
......@@ -347,12 +348,28 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
# get attention page size (for 1 token)
attn_page_size_1_token = FullAttentionSpec(
block_size=1,
num_kv_heads=model_config.get_num_kv_heads(parallel_config),
head_size=model_config.get_head_size(),
dtype=kv_cache_dtype,
).page_size_bytes
# Attention backend constraints:
# - FlashAttention (FA) requires block size to be multiple of 16
# - MLA (Multi-head Latent Attention) requires larger alignment:
# * CUTLASS_MLA backend: kernel_block_size 128 alignment
# * Other MLA backends: kernel_block_size 64 alignment
if model_config.use_mla:
use_cutlass_mla = envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA"
kernel_block_alignment_size = 128 if use_cutlass_mla else 64
attn_page_size_1_token = MLAAttentionSpec(
block_size=1,
num_kv_heads=model_config.get_num_kv_heads(parallel_config),
head_size=model_config.get_head_size(),
dtype=kv_cache_dtype,
).page_size_bytes
else:
kernel_block_alignment_size = 16
attn_page_size_1_token = FullAttentionSpec(
block_size=1,
num_kv_heads=model_config.get_num_kv_heads(parallel_config),
head_size=model_config.get_head_size(),
dtype=kv_cache_dtype,
).page_size_bytes
model_cls, _ = ModelRegistry.resolve_model_cls(
model_config.architecture,
......@@ -372,17 +389,6 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
if mamba_page_size == 0:
return
# Attention backend constraints:
# - FlashAttention (FA) requires block size to be multiple of 16
# - MLA (Multi-head Latent Attention) requires larger alignment:
# * CUTLASS_MLA backend: 128-byte alignment
# * Other MLA backends: 64-byte alignment
if model_config.use_mla:
use_cutlass_mla = envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA"
kernel_block_alignment_size = 128 if use_cutlass_mla else 64
else:
kernel_block_alignment_size = 16
if cache_config.enable_prefix_caching:
# With prefix caching, select attention block size to
# optimize for mamba kernel performance
......@@ -400,15 +406,8 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
# easily by changing the way we layout chunks in the
# mamba2 kernels.
from math import gcd
def lcm(a, b):
return a * b // gcd(a, b)
base_chunk_size = mamba_block_size or model_config.get_mamba_chunk_size()
base_chunk_size = model_config.get_mamba_chunk_size()
attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token)
chunk_size = lcm(base_chunk_size, kernel_block_alignment_size)
attn_block_size = chunk_size * cdiv(attn_tokens_per_mamba_state, chunk_size)
cache_config.mamba_block_size = attn_block_size
......
This diff is collapsed.
......@@ -118,6 +118,7 @@ _TEXT_GENERATION_MODELS = {
"InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
"JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
"JambaForCausalLM": ("jamba", "JambaForCausalLM"),
"KimiLinearForCausalLM": ("kimi_linear", "KimiLinearForCausalLM"), # noqa: E501
"Lfm2ForCausalLM": ("lfm2", "Lfm2ForCausalLM"),
"Lfm2MoeForCausalLM": ("lfm2_moe", "Lfm2MoeForCausalLM"),
"LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
......
......@@ -79,6 +79,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict(
deepseek_v3="DeepseekV3Config",
deepseek_v32="DeepseekV3Config",
flex_olmo="FlexOlmoConfig",
kimi_linear="KimiLinearConfig",
kimi_vl="KimiVLConfig",
Llama_Nemotron_Nano_VL="Nemotron_Nano_VL_Config",
RefinedWeb="RWConfig", # For tiiuae/falcon-40b(-instruct)
......
......@@ -19,6 +19,7 @@ from vllm.transformers_utils.configs.eagle import EAGLEConfig
from vllm.transformers_utils.configs.falcon import RWConfig
from vllm.transformers_utils.configs.flex_olmo import FlexOlmoConfig
from vllm.transformers_utils.configs.jais import JAISConfig
from vllm.transformers_utils.configs.kimi_linear import KimiLinearConfig
from vllm.transformers_utils.configs.kimi_vl import KimiVLConfig
from vllm.transformers_utils.configs.lfm2_moe import Lfm2MoeConfig
from vllm.transformers_utils.configs.medusa import MedusaConfig
......@@ -54,6 +55,7 @@ __all__ = [
"MiDashengLMConfig",
"MLPSpeculatorConfig",
"MoonViTConfig",
"KimiLinearConfig",
"KimiVLConfig",
"NemotronConfig",
"NemotronHConfig",
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from transformers.configuration_utils import PretrainedConfig
from vllm.logger import init_logger
logger = init_logger(__name__)
class KimiLinearConfig(PretrainedConfig):
    """Configuration object for models of type ``kimi_linear``.

    Stores the constructor arguments as attributes (dense transformer
    settings, MLA ranks/dims, MoE settings and the optional
    ``linear_attn_config`` dict) and exposes a few derived predicates.
    """

    model_type = "kimi_linear"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        model_type="kimi_linear",
        vocab_size=163840,
        hidden_size=4096,
        head_dim=None,
        intermediate_size=11008,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=None,
        hidden_act="silu",
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        rope_theta=10000.0,
        rope_scaling=None,
        tie_word_embeddings=False,
        moe_intermediate_size: int | None = None,
        moe_renormalize: bool = True,
        moe_router_activation_func: str = "sigmoid",
        num_experts: int | None = None,
        num_experts_per_token: int | None = None,
        num_shared_experts: int = 0,
        routed_scaling_factor: float = 1.0,
        first_k_dense_replace: int = 0,
        moe_layer_freq: int = 1,
        use_grouped_topk: bool = True,
        num_expert_group: int = 1,
        topk_group: int = 1,
        q_lora_rank: int | None = None,
        kv_lora_rank: int | None = None,
        qk_nope_head_dim: int | None = None,
        qk_rope_head_dim: int | None = None,
        v_head_dim: int | None = None,
        mla_use_nope: bool | None = False,
        num_nextn_predict_layers: int = 0,
        linear_attn_config: dict | None = None,
        **kwargs,
    ):
        """Store all settings, then initialise the ``PretrainedConfig`` base.

        ``head_dim`` defaults to ``hidden_size // num_attention_heads`` and
        ``num_key_value_heads`` defaults to ``num_attention_heads`` when not
        given. Token ids, ``tie_word_embeddings`` and any extra ``kwargs``
        are forwarded to the base class.
        """
        # Core transformer dimensions.
        self.model_type = model_type
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        if head_dim is None:
            head_dim = hidden_size // num_attention_heads
        self.head_dim = head_dim
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # for backward compatibility
        self.num_key_value_heads = (
            num_attention_heads if num_key_value_heads is None else num_key_value_heads
        )

        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling

        # MLA settings.
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.mla_use_nope = mla_use_nope

        # moe config
        self.num_experts = num_experts
        self.num_experts_per_token = num_experts_per_token
        self.moe_renormalize = moe_renormalize
        self.num_shared_experts = num_shared_experts
        self.routed_scaling_factor = routed_scaling_factor
        self.moe_router_activation_func = moe_router_activation_func
        assert self.moe_router_activation_func in ("softmax", "sigmoid")
        self.moe_intermediate_size = moe_intermediate_size
        self.first_k_dense_replace = first_k_dense_replace
        self.moe_layer_freq = moe_layer_freq
        self.use_grouped_topk = use_grouped_topk
        self.num_expert_group = num_expert_group
        self.topk_group = topk_group
        self.num_nextn_predict_layers = num_nextn_predict_layers

        # Linear-attention settings: when provided, both layer lists must
        # be present and non-None.
        if linear_attn_config is not None:
            assert linear_attn_config["kda_layers"] is not None
            assert linear_attn_config["full_attn_layers"] is not None
        self.linear_attn_config = linear_attn_config

        base_kwargs = dict(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
        )
        super().__init__(**base_kwargs, **kwargs)

    @property
    def is_mla(self):
        """True if any MLA rank/dim is set or ``mla_use_nope`` is ``True``."""
        mla_fields = (
            self.q_lora_rank,
            self.kv_lora_rank,
            self.qk_nope_head_dim,
            self.qk_rope_head_dim,
            self.v_head_dim,
        )
        if any(field is not None for field in mla_fields):
            return True
        return self.mla_use_nope is True

    @property
    def is_moe(self):
        """True when ``num_experts`` has been set."""
        return self.num_experts is not None

    @property
    def is_linear_attn(self) -> bool:
        """False when there is no linear-attention config or its
        ``kda_layers`` list is empty; True otherwise."""
        cfg = self.linear_attn_config
        if cfg is None:
            return False
        if not isinstance(cfg, dict):
            return True
        kda_layers = cfg["kda_layers"]
        return kda_layers is None or len(kda_layers) != 0

    def is_kda_layer(self, layer_idx: int):
        """Whether ``layer_idx + 1`` is listed in ``kda_layers``.

        Returns ``False`` when no linear-attention config is present.
        """
        cfg = self.linear_attn_config
        if cfg is None:
            return False
        return (layer_idx + 1) in cfg["kda_layers"]
......@@ -8,6 +8,7 @@ from collections import defaultdict
from collections.abc import Iterator
from contextlib import contextmanager
from copy import deepcopy
from functools import reduce
from itertools import product
from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias, cast
......@@ -4134,26 +4135,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def calculate_reorder_batch_threshold(self) -> None:
"""
Check that if any backends reorder batches; that the reordering
is compatible (e.g., decode threshold is the same)
Choose the minimum reorder batch threshold from all attention groups.
Backends should be able to support lower threshold than what they request
just may have a performance penalty due to that backend treating decodes
as prefills.
"""
for group in self._attn_group_iterator():
attn_metadata_builder_i = group.get_metadata_builder()
# check that if any backends reorder batches; that the reordering
# is compatible (e.g., decode threshold is the same)
reorder_batch_threshold_i = attn_metadata_builder_i.reorder_batch_threshold
if reorder_batch_threshold_i is not None:
if self.reorder_batch_threshold is not None:
if reorder_batch_threshold_i != self.reorder_batch_threshold:
raise ValueError(
f"Attention backend reorders decodes with "
f"threshold {reorder_batch_threshold_i} but other "
f"backend uses threshold "
f"{self.reorder_batch_threshold}"
)
else:
self.reorder_batch_threshold = reorder_batch_threshold_i
min_none_high = lambda a, b: a if b is None else b if a is None else min(a, b)
reorder_batch_thresholds = [
group.get_metadata_builder().reorder_batch_threshold
for group in self._attn_group_iterator()
]
self.reorder_batch_threshold = reduce(min_none_high, reorder_batch_thresholds)
def _find_compatible_block_sizes(
self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment