Unverified Commit e93f4cc9 authored by Tao He's avatar Tao He Committed by GitHub
Browse files

Add the support for the qwen3 next model (a hybrid attention model). (#24526)


Signed-off-by: default avatarTao He <linzhu.ht@alibaba-inc.com>
Co-authored-by: default avatarJee Jee Li <pandaleefree@gmail.com>
parent 2048c4e3
collect_env.py
vllm/model_executor/layers/fla/ops/*.py
......@@ -403,6 +403,7 @@ th {
| `Qwen2MoeForCausalLM` | Qwen2MoE | `Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Qwen3ForCausalLM` | Qwen3 | `Qwen/Qwen3-8B`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Qwen3MoeForCausalLM` | Qwen3MoE | `Qwen/Qwen3-30B-A3B`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Qwen3NextForCausalLM` | Qwen3.5MoE | `Qwen/Qwen3-Next-80B-A3B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `SeedOssForCausalLM` | SeedOss | `ByteDance-Seed/Seed-OSS-36B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `StableLmForCausalLM` | StableLM | `stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc. | | | ✅︎ |
| `Starcoder2ForCausalLM` | Starcoder2 | `bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc. | | ✅︎ | ✅︎ |
......
......@@ -228,6 +228,7 @@ fo = "fo"
ba = "ba"
[tool.typos.type.py.extend-words]
ba = "ba"
[tool.typos.type.cpp]
extend-glob = ["*.cu"]
......
......@@ -326,6 +326,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"Qwen2MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen1.5-MoE-A2.7B-Chat"),
"Qwen3ForCausalLM": _HfExamplesInfo("Qwen/Qwen3-8B"),
"Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"),
"Qwen3NextForCausalLM": _HfExamplesInfo("Qwen/Qwen3-Next-80B-A3B-Instruct",
min_transformers_version="4.56.2"),
"RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b"),
"SeedOssForCausalLM": _HfExamplesInfo("ByteDance-Seed/Seed-OSS-36B-Instruct", # noqa: E501
trust_remote_code=True,
......@@ -640,7 +642,9 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
is_available_online=False),
"MiMoMTPModel": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL",
trust_remote_code=True,
speculative_model="XiaomiMiMo/MiMo-7B-RL")
speculative_model="XiaomiMiMo/MiMo-7B-RL"),
"Qwen3NextMTP": _HfExamplesInfo("Qwen/Qwen3-Next-80B-A3B-Instruct",
min_transformers_version="4.56.2"),
}
_TRANSFORMERS_BACKEND_MODELS = {
......
......@@ -1508,7 +1508,8 @@ class ModelConfig:
if (self.hf_text_config.model_type == "deepseek_mtp"
or self.hf_config.model_type == "mimo_mtp"
or self.hf_config.model_type == "glm4_moe_mtp"
or self.hf_config.model_type == "ernie_mtp"):
or self.hf_config.model_type == "ernie_mtp"
or self.hf_config.model_type == "qwen3_next_mtp"):
total_num_hidden_layers = getattr(self.hf_text_config,
"num_nextn_predict_layers", 0)
else:
......@@ -1571,15 +1572,28 @@ class ModelConfig:
if attn_type_list:
return sum(t == 1 for t in attn_type_list[start:end])
if layers_block_type_value is None and attn_type_list is None:
# Hybrid model Qwen3Next
layer_types_value = getattr(self.hf_config, "layer_types", None)
if layer_types_value is not None:
if getattr(block_type, "value", block_type) == "attention":
return sum(t == "full_attention"
for t in layer_types_value[start:end])
elif getattr(block_type, "value",
block_type) == "linear_attention":
return sum(t == "linear_attention"
for t in layer_types_value[start:end])
else:
return sum(t == getattr(block_type, "value", block_type)
for t in layer_types_value[start:end])
if (layers_block_type_value is None and attn_type_list is None
and layer_types_value is None):
raise ValueError(
"The model is an hybrid without a"
"layers_block_type or an attn_type_list in the hf_config,"
"cannot determine the num of "
"layers_block_type or an attn_type_list, or a layer_types "
"in the hf_config, cannot determine the num of "
f"{block_type.value} layers")
return sum(t == 1 for t in attn_type_list[start:end])
def get_mamba_chunk_size(self) -> Optional[int]:
"""
Returns the mamba chunk size if it exists
......@@ -1866,7 +1880,7 @@ class DeviceConfig:
SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa",
"mlp_speculator", "draft_model", "deepseek_mtp",
"ernie_mtp"]
"ernie_mtp", "qwen3_next_mtp"]
@config
......@@ -2007,7 +2021,15 @@ class SpeculativeConfig:
"n_predict": n_predict,
"architectures": ["ErnieMTPModel"]
})
return hf_config
if hf_config.model_type == "qwen3_next":
hf_config.model_type = "qwen3_next_mtp"
if hf_config.model_type == "qwen3_next_mtp":
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
hf_config.update({
"n_predict": n_predict,
"architectures": ["Qwen3NextMTP"]
})
return hf_config
......@@ -2028,9 +2050,13 @@ class SpeculativeConfig:
(self.target_model_config.hf_text_config.model_type \
== "deepseek_v3" or
self.target_model_config.hf_text_config.model_type in
("mimo","ernie4_5_moe")):
("mimo","ernie4_5_moe", "qwen3_next")):
# use the draft model from the same model:
self.model = self.target_model_config.model
# Align the quantization of draft model for cases such as
# --quantization fp8 with a bf16 checkpoint.
if not self.quantization:
self.quantization = self.target_model_config.quantization
elif self.method in ("ngram", "[ngram]"):
self.model = "ngram"
else:
......@@ -2140,6 +2166,15 @@ class SpeculativeConfig:
"one layer. Might need some code changes " \
"to support multiple layers."
)
elif (self.draft_model_config.hf_config.model_type ==
"qwen3_next_mtp"):
self.method = "qwen3_next_mtp"
if self.num_speculative_tokens > 1:
logger.warning(
"All Qwen3Next MTP models only have " \
"one layer. Might need some code changes " \
"to support multiple layers."
)
else:
self.method = "draft_model"
raise NotImplementedError(
......@@ -2355,7 +2390,8 @@ class SpeculativeConfig:
return self.num_speculative_tokens
def use_eagle(self) -> bool:
return self.method in ("eagle", "eagle3", "deepseek_mtp", "ernie_mtp")
return self.method in ("eagle", "eagle3", "deepseek_mtp", "ernie_mtp",
"qwen3_next_mtp")
def __repr__(self) -> str:
method = self.method
......
......@@ -341,6 +341,7 @@ class CompilationConfig:
"vllm.short_conv",
"vllm.linear_attention",
"vllm.plamo2_mamba_mixer",
"vllm.gdn_attention",
]
def compute_hash(self) -> str:
......
......@@ -14,7 +14,7 @@ import torch
from vllm.triton_utils import tl, triton
from .index import prepare_chunk_indices, prepare_chunk_offsets
from .op import exp, safe_exp
from .op import exp
from .utils import is_nvidia_hopper, use_cuda_graph
NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8, 16]
......@@ -175,12 +175,13 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
boundary_check=(0, 1))
if USE_G:
m_t = (i_t * BT + tl.arange(0, BT)) < T
last_idx = min((i_t + 1) * BT, T) - 1
b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
p_g = tl.make_block_ptr(g + bos * H + i_h, (T, ), (H, ),
(i_t * BT, ), (BT, ), (0, ))
b_g = tl.load(p_g, boundary_check=(0, ))
b_v_new = b_v_new * safe_exp(b_g_last - b_g)[:, None]
b_v_new = b_v_new * tl.where(m_t, exp(b_g_last - b_g), 0)[:, None]
b_g_last = exp(b_g_last)
b_h1 = b_h1 * b_g_last
if K > 64:
......
......@@ -16,7 +16,7 @@ import torch
from vllm.triton_utils import tl, triton
from .index import prepare_chunk_indices
from .op import exp, safe_exp
from .op import exp
from .utils import FLA_GDN_FIX_BT, check_shared_mem, is_nvidia_hopper
BKV_LIST = [64, 128] if check_shared_mem() else [32, 64]
......@@ -112,10 +112,11 @@ def chunk_fwd_kernel_o(
p_g = tl.make_block_ptr(g, (T, ), (H, ), (i_t * BT, ), (BT, ), (0, ))
b_g = tl.load(p_g, boundary_check=(0, ))
b_o = b_o * exp(b_g)[:, None]
b_A = b_A * safe_exp(b_g[:, None] - b_g[None, :])
b_A = b_A * exp(b_g[:, None] - b_g[None, :])
o_i = tl.arange(0, BT)
m_A = o_i[:, None] >= o_i[None, :]
o_t = i_t * BT + tl.arange(0, BT)
m_t = o_t < T
m_A = (o_t[:, None] >= o_t[None, :]) & (m_t[:, None] & m_t)
b_A = tl.where(m_A, b_A, 0)
p_v = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_t * BT, i_v * BV),
......
......@@ -14,7 +14,7 @@ import torch
from vllm.triton_utils import tl, triton
from .index import prepare_chunk_indices
from .op import safe_exp
from .op import exp
@triton.heuristics({
......@@ -56,7 +56,8 @@ def chunk_scaled_dot_kkt_fwd_kernel(
T = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
o_t = tl.arange(0, BT)
o_t = i_t * BT + tl.arange(0, BT)
m_t = o_t < T
p_beta = tl.make_block_ptr(beta + bos * H + i_h, (T, ), (H, ),
(i_t * BT, ), (BT, ), (0, ))
......@@ -76,9 +77,10 @@ def chunk_scaled_dot_kkt_fwd_kernel(
(i_t * BT, ), (BT, ), (0, ))
b_g = tl.load(p_g, boundary_check=(0, ))
b_g_diff = b_g[:, None] - b_g[None, :]
b_A = b_A * safe_exp(b_g_diff)
b_A = b_A * exp(b_g_diff)
b_A = tl.where(o_t[:, None] > o_t[None, :], b_A, 0)
m_A = (o_t[:, None] > o_t[None, :]) & (m_t[:, None] & m_t)
b_A = tl.where(m_A, b_A, 0)
p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (BT * H, 1),
(i_t * BT, 0), (BT, BT), (1, 0))
tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
......
......@@ -116,8 +116,8 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
b_g = tl.load(p_g).to(tl.float32)
if USE_QK_L2NORM_IN_KERNEL:
b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6)
b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6)
b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6)
b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6)
b_q = b_q * scale
# [BK, BV]
b_h *= exp(b_g)
......
......@@ -78,7 +78,7 @@ def l2norm_fwd_kernel2(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr):
row_idx = xoffset + tl.arange(0, MBLOCK)[:, None]
xmask = row_idx < M
rindex = tl.arange(0, N)[None, :]
xs = tl.load(X + (rindex + N * row_idx), None).to(tl.float32)
xs = tl.load(X + (rindex + N * row_idx), xmask).to(tl.float32)
square = tl.broadcast_to(xs * xs, [MBLOCK, N])
square_sum = tl.sum(tl.where(xmask, square, 0), 1)[:, None]
rsqrt = tl.rsqrt(square_sum + eps)
......
......@@ -28,11 +28,6 @@ else:
log2 = tl.log2
@triton.jit
def safe_exp(x):
return exp(tl.where(x <= 0, x, float('-inf')))
if not hasattr(tl, 'gather'):
@triton.jit
......
......@@ -70,6 +70,15 @@ class MambaStateDtypeCalculator:
model_dtype)
return (conv_state_dtype, )
@classmethod
def gated_delta_net_state_dtype(
cls,
model_dtype: Union[ModelDType, torch.dtype],
mamba_cache_dtype: MambaDType,
) -> tuple[torch.dtype, torch.dtype]:
state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
return (state_dtype, state_dtype)
class MambaStateShapeCalculator:
......@@ -163,3 +172,31 @@ class MambaStateShapeCalculator:
# for n_groups == 1, this is exactly tp_size - n_groups
return tp_size - ngroups
@classmethod
def gated_delta_net_state_shape(
cls,
tp_world_size: int,
num_k_heads: int,
num_v_heads: int,
head_k_dim: int,
head_v_dim: int,
conv_kernel_size: int,
num_spec: int = 0,
use_v1: bool = True,
):
conv_dim = (head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads)
conv_state_shape = (
divide(conv_dim, tp_world_size),
conv_kernel_size - 1 + num_spec,
)
# In V0, the conv_state shape was swapped during allocation in
# MambaCacheManager, but in V1 it needs to be determined here at the
# calculation level
if use_v1:
conv_state_shape = conv_state_shape[1], conv_state_shape[0]
temporal_state_shape = (divide(num_v_heads,
tp_world_size), head_k_dim, head_v_dim)
return conv_state_shape, temporal_state_shape
......@@ -464,7 +464,9 @@ def causal_conv1d_fn(
# 3. mapping from sequence x[idx] to a cache line at index as specified via cache_indices[idx]
# 4. computation can be skipped if cache_indices[idx] == pad_slot_id
num_cache_lines = conv_states.size(0)
assert (num_cache_lines, dim, width - 1) == conv_states.shape
assert (num_cache_lines == conv_states.shape[0]
and dim == conv_states.shape[1]
and width - 1 <= conv_states.shape[2])
stride_istate_seq = conv_states.stride(0)
stride_istate_dim = conv_states.stride(1)
stride_istate_token = conv_states.stride(2)
......@@ -623,6 +625,7 @@ def _causal_conv1d_update_kernel(
conv_state_ptr,
cache_seqlens_ptr, # circular buffer
conv_state_indices_ptr,
num_accepted_tokens_ptr,
o_ptr, # (batch, dim, seqlen)
# Matrix dimensions
batch: int,
......@@ -639,6 +642,7 @@ def _causal_conv1d_update_kernel(
stride_conv_state_seq: tl.constexpr,
stride_conv_state_dim: tl.constexpr,
stride_conv_state_tok: tl.constexpr,
stride_state_indices: tl.constexpr,
stride_o_seq: tl.constexpr,
stride_o_dim: tl.constexpr,
stride_o_token: tl.constexpr,
......@@ -649,6 +653,7 @@ def _causal_conv1d_update_kernel(
KERNEL_WIDTH: tl.constexpr,
SILU_ACTIVATION: tl.constexpr,
IS_CONTINUOUS_BATCHING: tl.constexpr,
IS_SPEC_DECODING: tl.constexpr,
NP2_STATELEN: tl.constexpr,
USE_PAD_SLOT: tl.constexpr,
BLOCK_N: tl.constexpr,
......@@ -663,7 +668,8 @@ def _causal_conv1d_update_kernel(
if IS_CONTINUOUS_BATCHING:
# mask = idx_seq < batch
conv_state_batch_coord = tl.load(conv_state_indices_ptr + idx_seq).to(
conv_state_batch_coord = tl.load(conv_state_indices_ptr +
idx_seq * stride_state_indices).to(
tl.int64)
else:
conv_state_batch_coord = idx_seq
......@@ -672,13 +678,32 @@ def _causal_conv1d_update_kernel(
# not processing as this is not the actual sequence
return
if IS_SPEC_DECODING:
# The rolling of conv state:
#
# Before forward, the conv_state is:
# [history1, history2, ..., historyM].
#
# After forward, the conv_state becomes:
# [history2, ..., historyM, draft1, draft2, ..., draftN].
#
# After acceptance, it becomes:
#
# - accept 1 tokens: [history2, ..., historyM, draft1]
# - accept 2 tokens: [history3, ..., historyM, draft1, draft2]
# - and so on.
conv_state_token_offset = (tl.load(num_accepted_tokens_ptr + idx_seq) -
1)
else:
conv_state_token_offset = 0
# STEP 1: READ init_state data
conv_states_base = (conv_state_ptr +
(conv_state_batch_coord * stride_conv_state_seq) +
(idx_feats * stride_conv_state_dim))
mask_w = idx_feats < dim
prior_tokens = conv_states_base
prior_tokens = conv_states_base + conv_state_token_offset * stride_conv_state_tok
if KERNEL_WIDTH >= 2:
conv_states_ptrs = prior_tokens # [BLOCK_N]
col0 = tl.load(conv_states_ptrs, mask_w, 0.0)
......@@ -695,10 +720,14 @@ def _causal_conv1d_update_kernel(
# STEP 2: assume state_len > seqlen
idx_tokens = tl.arange(0, NP2_STATELEN) # [BLOCK_M]
# The conv_state updates works in a sliding window manner,
# at each forward pass, the tokens are shift by 1, so we
# load since idx_tokens + 1.
conv_state_ptrs_source = (
conv_state_ptr + (conv_state_batch_coord * stride_conv_state_seq) +
conv_state_token_offset * stride_conv_state_tok +
(idx_feats * stride_conv_state_dim)[None, :] +
((idx_tokens + seqlen) * stride_conv_state_tok)[:, None]
((idx_tokens + 1) * stride_conv_state_tok)[:, None]
) # [BLOCK_M, BLOCK_N]
mask = ((conv_state_batch_coord < num_cache_lines)
& ((idx_tokens + seqlen) < state_len)[:, None]
......@@ -820,6 +849,7 @@ def causal_conv1d_update(
activation: Union[bool, str, None] = None,
cache_seqlens: Optional[torch.Tensor] = None,
conv_state_indices: Optional[torch.Tensor] = None,
num_accepted_tokens: Optional[torch.Tensor] = None,
pad_slot_id: int = PAD_SLOT_ID,
metadata=None,
validate_data=False,
......@@ -890,10 +920,11 @@ def causal_conv1d_update(
) # X (batch, dim, seqlen)
stride_o_seq, stride_o_dim, stride_o_token = out.stride()
stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride(
)
state_len = width - 1
stride_state_indices = conv_state_indices.stride(
0) if conv_state_indices is not None else 0
state_len = width - 1 + (seqlen - 1) # effective state_len needed
np2_statelen = triton.next_power_of_2(state_len)
def grid(META):
......@@ -910,6 +941,7 @@ def causal_conv1d_update(
conv_state,
cache_seqlens,
conv_state_indices,
num_accepted_tokens,
out,
# Matrix dimensions
batch,
......@@ -926,6 +958,7 @@ def causal_conv1d_update(
stride_istate_seq,
stride_istate_dim,
stride_istate_token,
stride_state_indices,
stride_o_seq,
stride_o_dim,
stride_o_token,
......@@ -936,6 +969,7 @@ def causal_conv1d_update(
KERNEL_WIDTH=width,
SILU_ACTIVATION=activation in ["silu", "swish"],
IS_CONTINUOUS_BATCHING=conv_state_indices is not None,
IS_SPEC_DECODING=num_accepted_tokens is not None,
NP2_STATELEN=np2_statelen,
USE_PAD_SLOT=pad_slot_id is not None,
BLOCK_N=256,
......
......@@ -312,7 +312,8 @@ class MambaModelConfig(VerifyAndUpdateConfig):
# TODO(tdoublep): remove as full cuda graph support is added
FCG_NOT_SUPPORTED_MODELS = [
"Lfm2ForCausalLM", "MiniMaxText01ForCausalLM"
"Lfm2ForCausalLM",
"MiniMaxText01ForCausalLM",
]
if (model_config.architecture not in FCG_NOT_SUPPORTED_MODELS
......
This diff is collapsed.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Inference-only Qwen3Next MTP model."""
from collections.abc import Iterable
from typing import Optional
import torch
from torch import nn
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.qwen3_next import (Qwen3NextDecoderLayer,
Qwen3NextRMSNorm)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import Qwen3NextConfig
from .interfaces import SupportsPP
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, maybe_prefix)
logger = init_logger(__name__)
KVCache = tuple[torch.Tensor, torch.Tensor]
@support_torch_compile
class Qwen3NextMultiTokenPredictor(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
config: Qwen3NextConfig = model_config.hf_config
self.config = config
lora_vocab = ((lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0)
self.vocab_size = config.vocab_size + lora_vocab
self.org_vocab_size = config.vocab_size
self.mtp_start_layer_idx = config.num_hidden_layers
self.num_mtp_layers = getattr(config, "num_nextn_predict_layers", 1)
self.embed_tokens = VocabParallelEmbedding(
self.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
)
self.fc = ColumnParallelLinear(self.config.hidden_size * 2,
self.config.hidden_size,
gather_output=True,
bias=False,
return_bias=False)
self.layers = torch.nn.ModuleList(
Qwen3NextDecoderLayer(
config,
layer_type="full_attention",
model_config=model_config,
cache_config=cache_config,
quant_config=quant_config,
prefix=f'{prefix}.layers.{self.mtp_start_layer_idx + idx}',
) for idx in range(self.num_mtp_layers))
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
self.norm = Qwen3NextRMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.pre_fc_norm_hidden = Qwen3NextRMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.pre_fc_norm_embedding = Qwen3NextRMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
spec_step_idx: int = 0,
) -> torch.Tensor:
if get_pp_group().is_first_rank:
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings(input_ids)
assert hidden_states.shape[-1] == inputs_embeds.shape[-1]
inputs_embeds = self.pre_fc_norm_embedding(inputs_embeds)
hidden_states = self.pre_fc_norm_hidden(hidden_states)
hidden_states = torch.cat([inputs_embeds, hidden_states], dim=-1)
hidden_states = self.fc(hidden_states)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
current_step_idx = (spec_step_idx % self.num_mtp_layers)
hidden_states, residual = self.layers[current_step_idx](
positions=positions,
hidden_states=hidden_states,
residual=residual,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.num_experts)
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
if "mlp.experts" in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
# Skip loading extra bias for GPTQ models.
if ((name.endswith(".bias") or name.endswith("_bias"))
and name not in params_dict):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param,
loaded_weight,
name,
shard_id=shard_id,
expert_id=expert_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
@support_torch_compile
class Qwen3NextMTP(nn.Module, SupportsPP):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": ["up_proj", "down_proj"]
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
self.vllm_config = vllm_config
cache_config = vllm_config.cache_config
assert not cache_config.enable_prefix_caching, \
"Qwen3NextMTP currently does not support prefix caching"
self.quant_config = vllm_config.quant_config
super().__init__()
self.config = config
self.model = Qwen3NextMultiTokenPredictor(vllm_config=vllm_config,
prefix=maybe_prefix(
prefix, "model"))
self.unpadded_vocab_size = config.vocab_size
self.lm_head = ParallelLMHead(self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
):
hidden_states = self.model(input_ids, positions, hidden_states,
intermediate_tensors, inputs_embeds)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
spec_step_idx: int = 0,
) -> Optional[torch.Tensor]:
return self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
shared_weight_names = ["embed_tokens", "lm_head"]
def remap_weight_names(weights):
for name, weight in weights:
if name.startswith("mtp."):
name = name.replace("mtp.", "model.")
elif not any(key in name for key in shared_weight_names):
continue
yield name, weight
loader = AutoWeightsLoader(self)
return loader.load_weights(remap_weight_names(weights))
......@@ -74,6 +74,7 @@ _TEXT_GENERATION_MODELS = {
"Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
"Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
"Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"),
"Qwen3NextForCausalLM": ("qwen3_next", "Qwen3NextForCausalLM"),
"GlmForCausalLM": ("glm", "GlmForCausalLM"),
"Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
"Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
......@@ -285,6 +286,7 @@ _SPECULATIVE_DECODING_MODELS = {
"ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
"Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
"MedusaModel": ("medusa", "Medusa"),
"Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"),
# Temporarily disabled.
# # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
# "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
......
......@@ -79,7 +79,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict(
ultravox="UltravoxConfig",
step3_vl="Step3VLConfig",
step3_text="Step3TextConfig",
)
qwen3_next="Qwen3NextConfig")
_CONFIG_ATTRS_MAPPING: dict[str, str] = {
"llm_config": "text_config",
......
......@@ -24,6 +24,7 @@ from vllm.transformers_utils.configs.nemotron import NemotronConfig
from vllm.transformers_utils.configs.nemotron_h import NemotronHConfig
from vllm.transformers_utils.configs.nemotron_vl import Nemotron_Nano_VL_Config
from vllm.transformers_utils.configs.ovis import OvisConfig
from vllm.transformers_utils.configs.qwen3_next import Qwen3NextConfig
from vllm.transformers_utils.configs.speculators.base import SpeculatorsConfig
from vllm.transformers_utils.configs.step3_vl import (Step3TextConfig,
Step3VisionEncoderConfig,
......@@ -50,4 +51,5 @@ __all__ = [
"Step3VLConfig",
"Step3VisionEncoderConfig",
"Step3TextConfig",
"Qwen3NextConfig",
]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment