Unverified Commit e93f4cc9 authored by Tao He's avatar Tao He Committed by GitHub
Browse files

Add the support for the qwen3 next model (a hybrid attention model). (#24526)


Signed-off-by: default avatarTao He <linzhu.ht@alibaba-inc.com>
Co-authored-by: default avatarJee Jee Li <pandaleefree@gmail.com>
parent 2048c4e3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Qwen3-Next model configuration"""
from transformers.configuration_utils import (PretrainedConfig,
layer_type_validation)
from transformers.modeling_rope_utils import rope_config_validation
from transformers.utils import logging
logger = logging.get_logger(__name__)
class Qwen3NextConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Qwen3NextModel`]. It is used to instantiate a
Qwen3-Next model according to the specified arguments, defining the model architecture.
Instantiating a configuration with the defaults will yield a similar configuration to that of
Qwen3-Next-80B-A3B-Instruct [Qwen/Qwen3-Next-80B-A3B-Instruct](https://huggingface.co/Qwen/Qwen3-Next-80B-A3B-Instruct).
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 151936):
Vocabulary size of the model. Defines the number of different tokens that can be represented by the
`inputs_ids`.
hidden_size (`int`, *optional*, defaults to 2048):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 5632):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 48):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (`int`, *optional*, defaults to 2):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
hidden_act (`str`, *optional*, defaults to `"silu"`):
The non-linear activation function in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 32768):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether the model's input and output word embeddings should be tied.
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
accordingly.
Expected contents:
`rope_type` (`str`):
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
'llama3'], with 'default' being the original RoPE implementation.
`factor` (`float`, *optional*):
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
original maximum pre-trained length.
`original_max_position_embeddings` (`int`, *optional*):
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
pretraining.
`attention_factor` (`float`, *optional*):
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
computation. If unspecified, it defaults to value recommended by the implementation, using the
`factor` field to infer the suggested value.
`beta_fast` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
ramp function. If unspecified, it defaults to 32.
`beta_slow` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
ramp function. If unspecified, it defaults to 1.
`short_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`long_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`low_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
`high_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
partial_rotary_factor (`float`, *optional*, defaults to 0.25):
Percentage of the query and keys which will have rotary embedding.
attention_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
head_dim (`int`, *optional*, defaults to 256):
Projection weights dimension in multi-head attention.
linear_conv_kernel_dim (`int`, *optional*, defaults to 4):
Kernel size of the convolution used in linear attention layers.
linear_key_head_dim (`int`, *optional*, defaults to 128):
Dimension of each key head in linear attention.
linear_value_head_dim (`int`, *optional*, defaults to 128):
Dimension of each value head in linear attention.
linear_num_key_heads (`int`, *optional*, defaults to 16):
Number of key heads used in linear attention layers.
linear_num_value_heads (`int`, *optional*, defaults to 32):
Number of value heads used in linear attention layers.
decoder_sparse_step (`int`, *optional*, defaults to 1):
The frequency of the MoE layer.
moe_intermediate_size (`int`, *optional*, defaults to 512):
Intermediate size of the routed expert.
shared_expert_intermediate_size (`int`, *optional*, defaults to 512):
Intermediate size of the shared expert.
num_experts_per_tok (`int`, *optional*, defaults to 10):
Number of selected experts.
num_experts (`int`, *optional*, defaults to 512):
Number of routed experts.
norm_topk_prob (`bool`, *optional*, defaults to `True`):
Whether to normalize the topk probabilities.
output_router_logits (`bool`, *optional*, defaults to `False`):
Whether or not the router logits should be returned by the model. Enabling this will also
allow the model to output the auxiliary loss, including load balancing loss and router z-loss.
router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
The aux loss factor for the total loss.
mlp_only_layers (`list[int]`, *optional*, defaults to `[]`):
Indicate which layers use Qwen3NextMLP rather than Qwen3NextSparseMoeBlock
The list contains layer index, from 0 to num_layers-1 if we have num_layers layers
If `mlp_only_layers` is empty, `decoder_sparse_step` is used to determine the sparsity.
layer_types (`list[str]`, *optional*):
Types of each layer (attention or linear).
```python
>>> from transformers import Qwen3NextModel, Qwen3NextConfig
>>> # Initializing a Qwen3Next style configuration
>>> configuration = Qwen3NextConfig()
>>> # Initializing a model from the Qwen3-Next-80B-A3B style configuration
>>> model = Qwen3NextModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
""" # noqa: E501
model_type = "qwen3_next"
keys_to_ignore_at_inference = ["past_key_values"]
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.experts.*.gate_proj": "colwise",
"layers.*.mlp.experts.*.up_proj": "colwise",
"layers.*.mlp.experts.*.down_proj": "rowwise",
"layers.*.mlp.shared_experts.gate_proj": "colwise",
"layers.*.mlp.shared_experts.up_proj": "colwise",
"layers.*.mlp.shared_experts.down_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}
def __init__(
self,
vocab_size=151936,
hidden_size=2048,
intermediate_size=5632,
num_hidden_layers=48,
num_attention_heads=16,
num_key_value_heads=2,
hidden_act="silu",
max_position_embeddings=32768,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
tie_word_embeddings=False,
rope_theta=10000.0,
rope_scaling=None,
partial_rotary_factor=0.25,
attention_bias=False,
attention_dropout=0.0,
head_dim=256,
linear_conv_kernel_dim=4,
linear_key_head_dim=128,
linear_value_head_dim=128,
linear_num_key_heads=16,
linear_num_value_heads=32,
decoder_sparse_step=1,
moe_intermediate_size=512,
shared_expert_intermediate_size=512,
num_experts_per_tok=10,
num_experts=512,
norm_topk_prob=True,
output_router_logits=False,
router_aux_loss_coef=0.001,
mlp_only_layers=None,
layer_types=None,
**kwargs,
):
if mlp_only_layers is None:
mlp_only_layers = []
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.partial_rotary_factor = partial_rotary_factor
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.head_dim = head_dim
rope_config_validation(self)
self.layer_types = layer_types
if self.layer_types is None:
self.layer_types = [
"linear_attention" if bool((i + 1) % 4) else "full_attention"
for i in range(self.num_hidden_layers)
]
layer_type_validation(self.layer_types)
# linear attention part
self.linear_conv_kernel_dim = linear_conv_kernel_dim
self.linear_key_head_dim = linear_key_head_dim
self.linear_value_head_dim = linear_value_head_dim
self.linear_num_key_heads = linear_num_key_heads
self.linear_num_value_heads = linear_num_value_heads
# MoE arguments
self.decoder_sparse_step = decoder_sparse_step
self.moe_intermediate_size = moe_intermediate_size
self.shared_expert_intermediate_size = shared_expert_intermediate_size
self.num_experts_per_tok = num_experts_per_tok
self.num_experts = num_experts
self.norm_topk_prob = norm_topk_prob
self.output_router_logits = output_router_logits
self.router_aux_loss_coef = router_aux_loss_coef
self.mlp_only_layers = mlp_only_layers
__all__ = ["Qwen3NextConfig"]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Backend for GatedDeltaNet attention."""
from dataclasses import dataclass
from typing import ClassVar, Optional
import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.config import VllmConfig
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
class GDNAttentionBackend(AttentionBackend):
@staticmethod
def get_builder_cls() -> type["GDNAttentionMetadataBuilder"]:
return GDNAttentionMetadataBuilder
@dataclass
class GDNAttentionMetadata:
num_prefills: int
num_prefill_tokens: int
num_decodes: int
num_decode_tokens: int
num_spec_decodes: int
num_spec_decode_tokens: int
has_initial_state: Optional[torch.Tensor] = None
spec_query_start_loc: Optional[
torch.Tensor] = None # shape: [num_spec_decodes + 1,]
non_spec_query_start_loc: Optional[
torch.Tensor] = None # shape: [batch - num_spec_decodes + 1,]
spec_state_indices_tensor: Optional[
torch.Tensor] = None # shape: [batch, num_spec]
non_spec_state_indices_tensor: Optional[
torch.Tensor] = None # shape: [batch - num_spec_decodes,]
spec_sequence_masks: Optional[torch.Tensor] = None # shape: [batch,]
spec_token_masks: Optional[
torch.
Tensor] = None # shape: [num_prefill_tokens + num_decode_tokens,]
num_accepted_tokens: Optional[torch.Tensor] = None # shape: [batch,]
class GDNAttentionMetadataBuilder(
AttentionMetadataBuilder[GDNAttentionMetadata]):
cudagraph_support = AttentionCGSupport.UNIFORM_BATCH
reorder_batch_threshold: ClassVar[int] = 1
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
assert isinstance(kv_cache_spec, MambaSpec)
self.vllm_config = vllm_config
self.compilation_config = vllm_config.compilation_config
self.speculative_config = vllm_config.speculative_config
self.kv_cache_spec = kv_cache_spec
if self.speculative_config:
self.num_spec = self.speculative_config.num_speculative_tokens # noqa: E501
else:
self.num_spec = 0
self.use_spec_decode = self.num_spec > 0
self.reorder_batch_threshold = self.num_spec + 1 # type: ignore[misc]
self.use_full_cuda_graph = \
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
self.decode_cudagraph_max_bs = min(
self.vllm_config.scheduler_config.max_num_seqs,
self.compilation_config.max_capture_size)
self.spec_state_indices_tensor = torch.empty(
(self.decode_cudagraph_max_bs, self.num_spec + 1),
dtype=torch.int32,
device=device,
)
self.non_spec_state_indices_tensor = torch.empty(
(self.decode_cudagraph_max_bs, ),
dtype=torch.int32,
device=device,
)
self.spec_sequence_masks = torch.empty(
(self.decode_cudagraph_max_bs, ),
dtype=torch.bool,
device=device,
)
self.spec_token_masks = torch.empty(
(self.decode_cudagraph_max_bs * (self.num_spec + 1), ),
dtype=torch.bool,
device=device,
)
self.spec_query_start_loc = torch.empty(
(self.decode_cudagraph_max_bs + 1, ),
dtype=torch.int32,
device=device,
)
self.non_spec_query_start_loc = torch.empty(
(self.decode_cudagraph_max_bs + 1, ),
dtype=torch.int32,
device=device,
)
self.num_accepted_tokens = torch.empty(
(self.decode_cudagraph_max_bs, ),
dtype=torch.int32,
device=device,
)
def build( # type: ignore[override]
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
num_accepted_tokens: Optional[torch.Tensor] = None,
num_draft_tokens: Optional[torch.Tensor] = None,
fast_build: bool = False,
) -> GDNAttentionMetadata:
m = common_attn_metadata
query_start_loc = m.query_start_loc
context_lens = m.num_computed_tokens_cpu
context_lens_tensor = context_lens.to(query_start_loc.device)
seq_lens_tensor = m.seq_lens
if (not self.use_spec_decode or num_draft_tokens is None
or num_draft_tokens.sum().item() == 0):
spec_sequence_masks = None
else:
spec_sequence_masks = (num_draft_tokens > 0) & (
context_lens_tensor +
(num_draft_tokens + 1) == seq_lens_tensor)
if spec_sequence_masks.sum().item() == 0:
spec_sequence_masks = None
if spec_sequence_masks is None:
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(m, decode_threshold=1))
num_spec_decodes = 0
num_spec_decode_tokens = 0
spec_token_masks = None
spec_state_indices_tensor = None
non_spec_state_indices_tensor = m.block_table_tensor[:, 0]
spec_query_start_loc = None
non_spec_query_start_loc = query_start_loc
num_accepted_tokens = None
else:
num_spec_decodes = spec_sequence_masks.sum().item()
query_lens = query_start_loc[1:] - query_start_loc[:-1]
non_spec_query_lens = query_lens[~spec_sequence_masks]
num_decodes = (non_spec_query_lens == 1).sum().item()
num_prefills = non_spec_query_lens.size(0) - num_decodes
num_decode_tokens = num_decodes
num_prefill_tokens = non_spec_query_lens.sum().item(
) - num_decode_tokens
if num_prefills == 0 and num_decodes == 0:
spec_token_masks = torch.ones(
(min(num_spec_decodes *
(self.num_spec + 1), query_start_loc[-1].item())),
dtype=torch.bool,
device=query_start_loc.device)
spec_state_indices_tensor = m.block_table_tensor[:, :self.
num_spec + 1]
non_spec_state_indices_tensor = None
spec_query_start_loc = query_start_loc
non_spec_query_start_loc = None
else:
spec_token_masks = torch.repeat_interleave(
spec_sequence_masks, query_lens)
spec_state_indices_tensor = m.block_table_tensor[
spec_sequence_masks, :self.num_spec + 1]
non_spec_state_indices_tensor = \
m.block_table_tensor[~spec_sequence_masks, 0]
spec_query_start_loc = torch.zeros(
num_spec_decodes + 1,
dtype=torch.int32,
device=query_start_loc.device)
torch.cumsum(query_lens[spec_sequence_masks],
dim=0,
out=spec_query_start_loc[1:])
non_spec_query_start_loc = torch.zeros(
query_lens.size(0) - num_spec_decodes + 1,
dtype=torch.int32,
device=query_start_loc.device)
torch.cumsum(query_lens[~spec_sequence_masks],
dim=0,
out=non_spec_query_start_loc[1:])
num_spec_decode_tokens = min(
num_spec_decodes * (self.num_spec + 1),
spec_token_masks.size(0))
assert num_accepted_tokens is not None
num_accepted_tokens = num_accepted_tokens[spec_sequence_masks]
if num_prefills > 0:
has_initial_state = context_lens_tensor > 0
if spec_sequence_masks is not None:
has_initial_state = has_initial_state[~spec_sequence_masks]
else:
has_initial_state = None
# prepare tensors for cudagraph
if (self.use_full_cuda_graph and num_prefills == 0 and num_decodes == 0
and num_spec_decodes <= self.decode_cudagraph_max_bs):
num_total_tokens = self.vllm_config.pad_for_cudagraph(
m.num_actual_tokens)
batch_size = num_total_tokens // (self.num_spec + 1)
self.spec_state_indices_tensor[:num_spec_decodes].copy_(
spec_state_indices_tensor, non_blocking=True)
spec_state_indices_tensor = self.spec_state_indices_tensor[:
batch_size]
spec_state_indices_tensor[num_spec_decodes:].fill_(PAD_SLOT_ID)
self.spec_sequence_masks[:num_spec_decodes].copy_(
spec_sequence_masks, non_blocking=True)
spec_sequence_masks = self.spec_sequence_masks[:batch_size]
spec_sequence_masks[num_spec_decodes:].fill_(False)
assert spec_token_masks is not None
self.spec_token_masks[:spec_token_masks.size(0)].copy_(
spec_token_masks, non_blocking=True)
spec_token_masks = self.spec_token_masks[:m.num_actual_tokens]
spec_token_masks[spec_token_masks.size(0):].fill_(False)
self.spec_query_start_loc[:num_spec_decodes + 1].copy_(
spec_query_start_loc, non_blocking=True)
spec_num_query_tokens = spec_query_start_loc[
-1] # type: ignore[index]
spec_query_start_loc = self.spec_query_start_loc[:batch_size + 1]
spec_query_start_loc[num_spec_decodes +
1:].fill_(spec_num_query_tokens)
self.num_accepted_tokens[:num_spec_decodes].copy_(
num_accepted_tokens, non_blocking=True)
num_accepted_tokens = self.num_accepted_tokens[:batch_size]
num_accepted_tokens[num_spec_decodes:].fill_(1)
if (self.use_full_cuda_graph and num_prefills == 0
and num_spec_decodes == 0
and num_decodes <= self.decode_cudagraph_max_bs):
num_total_tokens = self.vllm_config.pad_for_cudagraph(
m.num_actual_tokens)
batch_size = num_total_tokens
self.non_spec_state_indices_tensor[:num_decodes].copy_(
non_spec_state_indices_tensor, non_blocking=True)
non_spec_state_indices_tensor = \
self.non_spec_state_indices_tensor[:batch_size]
non_spec_state_indices_tensor[num_decodes:].fill_(PAD_SLOT_ID)
self.non_spec_query_start_loc[:num_decodes + 1].copy_(
non_spec_query_start_loc, non_blocking=True)
non_spec_num_query_tokens = non_spec_query_start_loc[
-1] # type: ignore[index]
non_spec_query_start_loc = \
self.non_spec_query_start_loc[:batch_size + 1]
non_spec_query_start_loc[num_decodes +
1:].fill_(non_spec_num_query_tokens)
attn_metadata = GDNAttentionMetadata(
num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens,
num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens,
num_spec_decodes=num_spec_decodes,
num_spec_decode_tokens=num_spec_decode_tokens,
has_initial_state=has_initial_state,
spec_query_start_loc=spec_query_start_loc,
non_spec_query_start_loc=non_spec_query_start_loc,
spec_state_indices_tensor=spec_state_indices_tensor,
non_spec_state_indices_tensor=non_spec_state_indices_tensor,
spec_sequence_masks=spec_sequence_masks,
spec_token_masks=spec_token_masks,
num_accepted_tokens=num_accepted_tokens,
)
return attn_metadata
def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata):
"""
This method builds the metadata for full cudagraph capture.
Currently, only decode is supported for full cudagraphs with Mamba.
"""
m = common_attn_metadata
assert (m.num_reqs * (self.num_spec + 1) <= m.num_actual_tokens
and ((m.num_reqs + 1) * (self.num_spec + 1)
>= m.num_actual_tokens)), \
"GDN only supports decode-only full CUDAGraph capture. " \
"Make sure all cudagraph capture sizes <= max_num_seq."
num_accepted_tokens = torch.full((m.num_reqs, ),
m.max_query_len,
dtype=torch.int32,
device=m.query_start_loc.device)
num_drafted_tokens = torch.full((m.num_reqs, ),
self.num_spec,
dtype=torch.int32,
device=m.query_start_loc.device)
# Fixes query-start loc for spec-sequence-indices.
m.query_start_loc = torch.arange(0,
m.num_actual_tokens + 1,
step=m.max_query_len,
device=m.query_start_loc.device,
dtype=torch.int32)
m.num_computed_tokens_cpu = (m.seq_lens_cpu - torch.full(
(m.num_reqs, ), m.max_query_len, dtype=torch.int32, device='cpu'))
return self.build(0, m, num_accepted_tokens, num_drafted_tokens)
...@@ -559,12 +559,48 @@ class MambaManager(SingleTypeKVCacheManager): ...@@ -559,12 +559,48 @@ class MambaManager(SingleTypeKVCacheManager):
num_running_requests: int) -> int: num_running_requests: int) -> int:
return 0 return 0
def get_num_blocks_to_allocate(
self, request_id: str, num_tokens: int,
new_computed_blocks: list[KVCacheBlock]) -> int:
"""
Get the number of blocks needed to be allocated for the request.
Args:
request_id: The request ID.
num_tokens: The total number of tokens that need a slot (including
tokens that are already allocated).
new_computed_blocks: The new computed blocks just hitting the
prefix caching.
Returns:
The number of blocks
"""
assert isinstance(self.kv_cache_spec, MambaSpec)
if self.kv_cache_spec.num_speculative_blocks > 0:
num_tokens += (self.kv_cache_spec.block_size *
self.kv_cache_spec.num_speculative_blocks)
num_required_blocks = cdiv(num_tokens, self.block_size)
num_new_blocks = (num_required_blocks - len(new_computed_blocks) -
len(self.req_to_blocks[request_id]))
# If a computed block of a request is an eviction candidate (in the
# free queue and ref_cnt == 0), it will be changed from a free block
# to a computed block when the request is allocated, so we also count
# it as needed to be allocated.
num_evictable_computed_blocks = sum(
blk.ref_cnt == 0 and not blk.is_null
for blk in new_computed_blocks)
return num_new_blocks + num_evictable_computed_blocks
def allocate_new_blocks(self, request_id: str, def allocate_new_blocks(self, request_id: str,
num_tokens: int) -> list[KVCacheBlock]: num_tokens: int) -> list[KVCacheBlock]:
new_blocks = super().allocate_new_blocks(request_id, num_tokens) # Allocate extra `num_speculative_blocks` blocks for
assert len(self.req_to_blocks[request_id]) == 1, ( # speculative decoding (MTP/EAGLE) with linear attention.
"MambaManager should only allocate 1 block for each request.") assert isinstance(self.kv_cache_spec, MambaSpec)
return new_blocks if self.kv_cache_spec.num_speculative_blocks > 0:
num_tokens += (self.kv_cache_spec.block_size *
self.kv_cache_spec.num_speculative_blocks)
return super().allocate_new_blocks(request_id, num_tokens)
class CrossAttentionManager(SingleTypeKVCacheManager): class CrossAttentionManager(SingleTypeKVCacheManager):
......
...@@ -194,6 +194,7 @@ class MambaSpec(KVCacheSpec): ...@@ -194,6 +194,7 @@ class MambaSpec(KVCacheSpec):
dtypes: tuple[torch.dtype] dtypes: tuple[torch.dtype]
page_size_padded: Optional[int] = None page_size_padded: Optional[int] = None
mamba_type: str = "mamba2" mamba_type: str = "mamba2"
num_speculative_blocks: int = 0
@property @property
def page_size_bytes(self) -> int: def page_size_bytes(self) -> int:
......
...@@ -218,7 +218,7 @@ class EagleProposer: ...@@ -218,7 +218,7 @@ class EagleProposer:
hidden_states=self.hidden_states[:num_input_tokens], hidden_states=self.hidden_states[:num_input_tokens],
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
) )
if self.method in ("deepseek_mtp", "ernie_mtp"): if self.method in ("deepseek_mtp", "ernie_mtp", "qwen3_next_mtp"):
last_hidden_states = ret_hidden_states last_hidden_states = ret_hidden_states
hidden_states = last_hidden_states hidden_states = last_hidden_states
else: else:
...@@ -322,12 +322,18 @@ class EagleProposer: ...@@ -322,12 +322,18 @@ class EagleProposer:
with set_forward_context(per_layer_attn_metadata, with set_forward_context(per_layer_attn_metadata,
self.vllm_config, self.vllm_config,
num_tokens=input_batch_size): num_tokens=input_batch_size):
last_hidden_states, hidden_states = self.model( ret_hidden_states = self.model(
input_ids=input_ids, input_ids=input_ids,
positions=self.positions[:input_batch_size], positions=self.positions[:input_batch_size],
hidden_states=self.hidden_states[:input_batch_size], hidden_states=self.hidden_states[:input_batch_size],
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
) )
if self.method in ("deepseek_mtp", "ernie_mtp",
"qwen3_next_mtp"):
last_hidden_states = ret_hidden_states
hidden_states = ret_hidden_states
else:
last_hidden_states, hidden_states = ret_hidden_states
hidden_states = hidden_states[:batch_size] hidden_states = hidden_states[:batch_size]
logits = self.model.compute_logits(last_hidden_states[:batch_size], logits = self.model.compute_logits(last_hidden_states[:batch_size],
None) None)
......
...@@ -156,9 +156,14 @@ class BlockTable: ...@@ -156,9 +156,14 @@ class BlockTable:
class MultiGroupBlockTable: class MultiGroupBlockTable:
"""The BlockTables for each KV cache group.""" """The BlockTables for each KV cache group."""
def __init__(self, max_num_reqs: int, max_model_len: int, def __init__(self,
max_num_batched_tokens: int, pin_memory: bool, max_num_reqs: int,
device: torch.device, block_sizes: list[int]) -> None: max_model_len: int,
max_num_batched_tokens: int,
pin_memory: bool,
device: torch.device,
block_sizes: list[int],
num_speculative_tokens: int = 0) -> None:
# Note(hc): each dcp rank only store # Note(hc): each dcp rank only store
# (max_model_len//dcp_world_size) tokens in kvcache, # (max_model_len//dcp_world_size) tokens in kvcache,
# so the block_size which used for calc max_num_blocks_per_req # so the block_size which used for calc max_num_blocks_per_req
...@@ -170,10 +175,11 @@ class MultiGroupBlockTable: ...@@ -170,10 +175,11 @@ class MultiGroupBlockTable:
dcp_world_size = 1 dcp_world_size = 1
self.block_tables = [ self.block_tables = [
BlockTable(block_size, max_num_reqs, BlockTable(
cdiv(max_model_len, block_size * dcp_world_size), block_size, max_num_reqs,
max_num_batched_tokens, pin_memory, device) max(cdiv(max_model_len, block_size * dcp_world_size),
for block_size in block_sizes 1 + num_speculative_tokens), max_num_batched_tokens,
pin_memory, device) for block_size in block_sizes
] ]
def append_row(self, block_ids: tuple[list[int], ...], def append_row(self, block_ids: tuple[list[int], ...],
......
...@@ -83,6 +83,7 @@ class InputBatch: ...@@ -83,6 +83,7 @@ class InputBatch:
logitsprocs: Optional[LogitsProcessors] = None, logitsprocs: Optional[LogitsProcessors] = None,
is_spec_decode: bool = False, is_spec_decode: bool = False,
is_pooling_model: bool = False, is_pooling_model: bool = False,
num_speculative_tokens: int = 0,
): ):
self.is_pooling_model = is_pooling_model self.is_pooling_model = is_pooling_model
self.is_spec_decode = is_spec_decode self.is_spec_decode = is_spec_decode
...@@ -127,6 +128,7 @@ class InputBatch: ...@@ -127,6 +128,7 @@ class InputBatch:
pin_memory=pin_memory, pin_memory=pin_memory,
device=device, device=device,
block_sizes=block_sizes, block_sizes=block_sizes,
num_speculative_tokens=num_speculative_tokens,
) )
# Sampling-related. # Sampling-related.
...@@ -202,6 +204,14 @@ class InputBatch: ...@@ -202,6 +204,14 @@ class InputBatch:
self.repetition_penalties_cpu_tensor.numpy() self.repetition_penalties_cpu_tensor.numpy()
self.repetition_penalties_reqs: set[str] = set() self.repetition_penalties_reqs: set[str] = set()
# Speculative decoding
self.num_accepted_tokens_cpu_tensor = torch.ones((max_num_reqs, ),
dtype=torch.int64,
device="cpu",
pin_memory=pin_memory)
self.num_accepted_tokens_cpu = \
self.num_accepted_tokens_cpu_tensor.numpy()
# lora related # lora related
self.request_lora_mapping = np.zeros((self.max_num_reqs, ), self.request_lora_mapping = np.zeros((self.max_num_reqs, ),
dtype=np.int32) dtype=np.int32)
...@@ -394,6 +404,9 @@ class InputBatch: ...@@ -394,6 +404,9 @@ class InputBatch:
else: else:
raise NotImplementedError("Unrecognized request type") raise NotImplementedError("Unrecognized request type")
# Speculative decoding: by default 1 token is generated.
self.num_accepted_tokens_cpu[req_index] = 1
# Add request lora ID # Add request lora ID
if request.lora_request: if request.lora_request:
lora_id = request.lora_request.lora_int_id lora_id = request.lora_request.lora_int_id
...@@ -515,6 +528,8 @@ class InputBatch: ...@@ -515,6 +528,8 @@ class InputBatch:
self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1] self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1]
self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] = \ self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] = \
self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1] self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1]
self.num_accepted_tokens_cpu[i1], self.num_accepted_tokens_cpu[i2] =\
self.num_accepted_tokens_cpu[i2], self.num_accepted_tokens_cpu[i1]
swap_dict_values(self.generators, i1, i2) swap_dict_values(self.generators, i1, i2)
swap_dict_values(self.bad_words_token_ids, i1, i2) swap_dict_values(self.bad_words_token_ids, i1, i2)
...@@ -609,6 +624,8 @@ class InputBatch: ...@@ -609,6 +624,8 @@ class InputBatch:
empty_index] = self.presence_penalties_cpu[last_req_index] empty_index] = self.presence_penalties_cpu[last_req_index]
self.repetition_penalties_cpu[ self.repetition_penalties_cpu[
empty_index] = self.repetition_penalties_cpu[last_req_index] empty_index] = self.repetition_penalties_cpu[last_req_index]
self.num_accepted_tokens_cpu[
empty_index] = self.num_accepted_tokens_cpu[last_req_index]
generator = self.generators.pop(last_req_index, None) generator = self.generators.pop(last_req_index, None)
if generator is not None: if generator is not None:
self.generators[empty_index] = generator self.generators[empty_index] = generator
......
...@@ -53,9 +53,9 @@ from vllm.sampling_params import SamplingType ...@@ -53,9 +53,9 @@ from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors, PoolerOutput from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.tasks import GenerationTask, PoolingTask, SupportedTask from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
GiB_bytes, LazyLoader, cdiv, check_use_alibi, GiB_bytes, LazyLoader, check_use_alibi, get_dtype_size,
get_dtype_size, is_pin_memory_available, round_up, is_pin_memory_available, round_up, supports_dynamo)
supports_dynamo) from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
create_fast_prefill_custom_backend, create_fast_prefill_custom_backend,
...@@ -324,6 +324,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -324,6 +324,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.hidden_size, self.hidden_size,
dtype=self.dtype, dtype=self.dtype,
numpy=False) numpy=False)
self.num_draft_tokens = self._make_buffer(self.max_num_reqs,
dtype=torch.int32)
self.num_accepted_tokens = self._make_buffer(self.max_num_reqs,
dtype=torch.int64)
# Only relevant for models using M-RoPE (e.g, Qwen2-VL) # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
if self.uses_mrope: if self.uses_mrope:
...@@ -663,6 +667,31 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -663,6 +667,31 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Refresh batch metadata with any pending updates. # Refresh batch metadata with any pending updates.
self.input_batch.refresh_metadata() self.input_batch.refresh_metadata()
def _update_states_after_model_execute(
self, output_token_ids: torch.Tensor) -> None:
"""Update the cached states after model execution.
This is used for MTP/EAGLE for hybrid models, as in linear attention,
only the last token's state is kept. In MTP/EAGLE, for draft tokens
the state are kept util we decide how many tokens are accepted for
each sequence, and a shifting is done during the next iteration
based on the number of accepted tokens.
"""
if not self.model_config.is_hybrid or not self.speculative_config:
return
# Find the number of accepted tokens for each sequence.
num_accepted_tokens = (torch.cat(
[
output_token_ids,
torch.full((output_token_ids.size(0), 1),
-1,
device=output_token_ids.device),
],
dim=1) == -1).int().argmax(-1).cpu().numpy()
for i, num_tokens in enumerate(num_accepted_tokens):
self.input_batch.num_accepted_tokens_cpu[i] = num_tokens
def _init_mrope_positions(self, req_state: CachedRequestState): def _init_mrope_positions(self, req_state: CachedRequestState):
image_grid_thw = [] image_grid_thw = []
video_grid_thw = [] video_grid_thw = []
...@@ -936,6 +965,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -936,6 +965,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# We will ignore the sampled tokens from the partial requests. # We will ignore the sampled tokens from the partial requests.
# TODO: Support prompt logprobs. # TODO: Support prompt logprobs.
logits_indices = query_start_loc[1:] - 1 logits_indices = query_start_loc[1:] - 1
num_draft_tokens = None
spec_decode_metadata = None spec_decode_metadata = None
else: else:
# Get the number of draft tokens for each request. # Get the number of draft tokens for each request.
...@@ -950,6 +980,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -950,6 +980,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
spec_decode_metadata = self._calc_spec_decode_metadata( spec_decode_metadata = self._calc_spec_decode_metadata(
num_draft_tokens, cu_num_tokens) num_draft_tokens, cu_num_tokens)
logits_indices = spec_decode_metadata.logits_indices logits_indices = spec_decode_metadata.logits_indices
self.num_draft_tokens.np[:num_reqs] = num_draft_tokens
self.num_draft_tokens.np[num_reqs:].fill(0)
self.num_draft_tokens.copy_to_gpu()
logits_indices_padded = None logits_indices_padded = None
if self.cache_config.kv_sharing_fast_prefill: if self.cache_config.kv_sharing_fast_prefill:
...@@ -964,6 +997,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -964,6 +997,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_computed_tokens_cpu = ( num_computed_tokens_cpu = (
self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs]) self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
spec_decode_common_attn_metadata = None spec_decode_common_attn_metadata = None
if use_spec_decode:
self.num_accepted_tokens.np[:num_reqs] = (
self.input_batch.num_accepted_tokens_cpu[:num_reqs])
self.num_accepted_tokens.np[num_reqs:].fill(1)
self.num_accepted_tokens.copy_to_gpu()
# Prepare the attention metadata for each KV cache group and make layers # Prepare the attention metadata for each KV cache group and make layers
# in the same group share the same metadata. # in the same group share the same metadata.
...@@ -1034,10 +1072,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -1034,10 +1072,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
builder, builder,
) )
extra_attn_metadata_args = {}
if use_spec_decode and isinstance(builder,
GDNAttentionMetadataBuilder):
extra_attn_metadata_args = dict(
num_accepted_tokens=self.num_accepted_tokens.
gpu[:num_reqs],
num_draft_tokens=self.num_draft_tokens.gpu[:num_reqs],
)
attn_metadata_i = builder.build( attn_metadata_i = builder.build(
common_prefix_len=common_prefix_len, common_prefix_len=common_prefix_len,
common_attn_metadata=common_attn_metadata, common_attn_metadata=common_attn_metadata,
) **extra_attn_metadata_args)
for layer_name in attn_group.layer_names: for layer_name in attn_group.layer_names:
attn_metadata[layer_name] = attn_metadata_i attn_metadata[layer_name] = attn_metadata_i
...@@ -1814,6 +1861,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -1814,6 +1861,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
sampling_metadata, sampling_metadata,
) )
sampler_output.sampled_token_ids = output_token_ids sampler_output.sampled_token_ids = output_token_ids
self._update_states_after_model_execute(output_token_ids)
return sampler_output return sampler_output
...@@ -2644,13 +2692,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -2644,13 +2692,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Note: Overriding max_query_len to be the prefill tokens # Note: Overriding max_query_len to be the prefill tokens
max_query_len = num_prefill_tokens max_query_len = num_prefill_tokens
elif uniform_decode: elif uniform_decode:
assert not create_mixed_batch num_reqs = num_tokens // max_query_len
num_reqs = cdiv(num_tokens, max_query_len)
assert num_reqs <= max_num_reqs, \ assert num_reqs <= max_num_reqs, \
"Do not capture num_reqs > max_num_reqs for uniform batch" "Do not capture num_reqs > max_num_reqs for uniform batch"
num_scheduled_tokens_list = [max_query_len] * num_reqs num_scheduled_tokens_list = [max_query_len] * num_reqs
if num_tokens % max_query_len != 0: if num_tokens % max_query_len != 0:
num_scheduled_tokens_list[-1] = num_tokens % max_query_len num_scheduled_tokens_list[-1] += num_tokens % max_query_len
else: else:
num_reqs = min(num_tokens, max_num_reqs) num_reqs = min(num_tokens, max_num_reqs)
min_tokens_per_req = num_tokens // num_reqs min_tokens_per_req = num_tokens // num_reqs
...@@ -3297,6 +3344,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -3297,6 +3344,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
is_spec_decode=bool(self.vllm_config.speculative_config), is_spec_decode=bool(self.vllm_config.speculative_config),
logitsprocs=self.input_batch.logitsprocs, logitsprocs=self.input_batch.logitsprocs,
is_pooling_model=self.is_pooling_model, is_pooling_model=self.is_pooling_model,
num_speculative_tokens=(
self.vllm_config.speculative_config.num_speculative_tokens
if self.vllm_config.speculative_config else 0),
) )
def _allocate_kv_cache_tensors( def _allocate_kv_cache_tensors(
...@@ -3647,7 +3697,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -3647,7 +3697,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
mamba_layers = get_layers_from_vllm_config(self.vllm_config, MambaBase) mamba_layers = get_layers_from_vllm_config(self.vllm_config, MambaBase)
if len(mamba_layers) > 0: if len(mamba_layers) > 0:
if self.vllm_config.speculative_config is not None: if (self.vllm_config.speculative_config is not None
and self.vllm_config.model_config.hf_config.model_type
not in ["qwen3_next"]):
raise NotImplementedError( raise NotImplementedError(
"Mamba with speculative decoding is not supported yet.") "Mamba with speculative decoding is not supported yet.")
if self.vllm_config.cache_config.enable_prefix_caching: if self.vllm_config.cache_config.enable_prefix_caching:
...@@ -3666,7 +3718,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -3666,7 +3718,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
dtypes=mamba_module.get_state_dtype(), dtypes=mamba_module.get_state_dtype(),
block_size=max_model_len, block_size=max_model_len,
page_size_padded=page_size_padded, page_size_padded=page_size_padded,
mamba_type=mamba_module.mamba_type) mamba_type=mamba_module.mamba_type,
num_speculative_blocks=(
self.speculative_config.num_speculative_tokens
if self.speculative_config else 0),
)
return kv_cache_spec return kv_cache_spec
......
...@@ -78,7 +78,8 @@ class Worker(LocalOrDistributedWorkerBase): ...@@ -78,7 +78,8 @@ class Worker(LocalOrDistributedWorkerBase):
"deepseek_mtp", "deepseek_mtp",
"glm4_moe_mtp", "glm4_moe_mtp",
"mimo_mtp", "mimo_mtp",
"ernie_mtp")) \ "ernie_mtp",
"qwen3_next_mtp")) \
else {"return_hidden_states": True} else {"return_hidden_states": True}
ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment