Unverified Commit a8eab8f3 authored by Xiaoshuang Wang's avatar Xiaoshuang Wang Committed by GitHub
Browse files

[Model] Extract GatedDeltaNetAttention into shared layer for Qwen3Next and Qwen3.5 (#37975)


Signed-off-by: default avatarwxsIcey <1790571317@qq.com>
Signed-off-by: default avatarIcey <1790571317@qq.com>
parent 2babac0b
This diff is collapsed.
...@@ -28,7 +28,6 @@ import typing ...@@ -28,7 +28,6 @@ import typing
from collections.abc import Callable, Iterable from collections.abc import Callable, Iterable
import torch import torch
from einops import rearrange
from torch import nn from torch import nn
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
...@@ -40,18 +39,14 @@ from vllm.logger import init_logger ...@@ -40,18 +39,14 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import ( from vllm.model_executor.layers.layernorm import (
GemmaRMSNorm as Qwen3_5RMSNorm, GemmaRMSNorm as Qwen3_5RMSNorm,
) )
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.gdn_linear_attn import GatedDeltaNetAttention
from vllm.model_executor.layers.mamba.mamba_utils import ( from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFunc, MambaStateCopyFunc,
MambaStateCopyFuncCalculator, MambaStateCopyFuncCalculator,
MambaStateDtypeCalculator, MambaStateDtypeCalculator,
MambaStateShapeCalculator, MambaStateShapeCalculator,
) )
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, ParallelLMHead,
VocabParallelEmbedding, VocabParallelEmbedding,
...@@ -85,7 +80,6 @@ from .qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP ...@@ -85,7 +80,6 @@ from .qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP
from .qwen3_next import ( from .qwen3_next import (
Qwen3NextAttention, Qwen3NextAttention,
Qwen3NextDecoderLayer, Qwen3NextDecoderLayer,
Qwen3NextGatedDeltaNet,
Qwen3NextModel, Qwen3NextModel,
Qwen3NextSparseMoeBlock, Qwen3NextSparseMoeBlock,
QwenNextMixtureOfExperts, QwenNextMixtureOfExperts,
...@@ -121,149 +115,6 @@ class Qwen3_5MoeProcessingInfo(Qwen3VLProcessingInfo): ...@@ -121,149 +115,6 @@ class Qwen3_5MoeProcessingInfo(Qwen3VLProcessingInfo):
return self.ctx.get_hf_config(Qwen3_5MoeConfig) return self.ctx.get_hf_config(Qwen3_5MoeConfig)
class Qwen3_5GatedDeltaNet(Qwen3NextGatedDeltaNet):
def fix_query_key_value_ordering(
self,
mixed_qkvz: torch.Tensor,
mixed_ba: torch.Tensor,
):
raise NotImplementedError(
"Qwen3.5 Series dont need to fix query key value ordering"
)
def __init__(
self,
config: Qwen3_5Config,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
create_in_proj_qkvz = vllm_config.lora_config is None
super().__init__(
config,
vllm_config=vllm_config,
prefix=prefix,
create_in_proj_qkvz=create_in_proj_qkvz,
)
if vllm_config.lora_config is not None:
# Separate in_proj_qkv (Q,K,V) and in_proj_z for LoRA compatibility.
# Use MergedColumnParallelLinear for in_proj_qkv because GDN can have
# linear_num_key_heads != linear_num_value_heads (e.g. 16 vs 32), so
# output sizes [key_dim, key_dim, value_dim] are not representable
# with a single QKVParallelLinear (which ties K and V head counts).
self.in_proj_qkv = MergedColumnParallelLinear(
input_size=self.hidden_size,
output_sizes=[self.key_dim, self.key_dim, self.value_dim],
bias=False,
quant_config=vllm_config.quant_config,
prefix=f"{prefix}.in_proj_qkv",
)
self.in_proj_z = ColumnParallelLinear(
input_size=self.hidden_size,
output_size=self.value_dim,
bias=False,
quant_config=vllm_config.quant_config,
prefix=f"{prefix}.in_proj_z",
)
def create_qkvz_proj(
self,
hidden_size: int,
key_dim: int,
value_dim: int,
quant_config: QuantizationConfig | None,
prefix: str,
) -> MergedColumnParallelLinear:
return MergedColumnParallelLinear(
input_size=hidden_size,
output_sizes=[key_dim, key_dim, value_dim, value_dim],
bias=False,
quant_config=quant_config,
prefix=prefix,
)
def create_ba_proj(
self,
hidden_size: int,
num_v_heads: int,
quant_config: QuantizationConfig | None,
prefix: str,
) -> MergedColumnParallelLinear:
# Qwen3.5 has separate in_proj_b and in_proj_a weights in the
# checkpoint, which are loaded into the fused in_proj_ba parameter
# via stacked_params_mapping with shard_id 0 and 1 respectively.
return MergedColumnParallelLinear(
input_size=hidden_size,
output_sizes=[num_v_heads] * 2,
bias=False,
quant_config=quant_config,
prefix=prefix,
)
def forward(
self,
hidden_states: torch.Tensor,
output: torch.Tensor,
):
"""
Forward pass with three parts:
1. Input projection
2. Core attention (custom op)
3. Output projection
"""
num_tokens = hidden_states.size(0)
# ============================================================
# Part 1: Input Projection
# ============================================================
if hasattr(self, "in_proj_qkv"):
# LoRA path: separate in_proj_qkv and in_proj_z
mixed_qkv, _ = self.in_proj_qkv(hidden_states)
ba, _ = self.in_proj_ba(hidden_states)
z, _ = self.in_proj_z(hidden_states)
else:
mixed_qkvz, _ = self.in_proj_qkvz(hidden_states)
ba, _ = self.in_proj_ba(hidden_states)
qkv_size = (self.key_dim * 2 + self.value_dim) // self.tp_size
z_size = self.value_dim // self.tp_size
mixed_qkv, z = mixed_qkvz.split([qkv_size, z_size], dim=-1)
z = z.reshape(z.size(0), -1, self.head_v_dim)
b, a = ba.chunk(2, dim=-1)
b = b.contiguous()
a = a.contiguous()
# ============================================================
# Part 2: Core Attention (Custom Op)
# ============================================================
# Note: we should not use torch.empty here like other attention backends,
# see discussions in https://github.com/vllm-project/vllm/pull/28182
core_attn_out = torch.zeros(
(num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
torch.ops.vllm.gdn_attention_core(
mixed_qkv,
b,
a,
core_attn_out,
self.prefix,
)
# ============================================================
# Part 3: Output Projection
# ============================================================
z_shape_og = z.shape
# Reshape input data into 2D tensor
core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
z = z.reshape(-1, z.shape[-1])
core_attn_out = self.norm(core_attn_out, z)
core_attn_out = core_attn_out.reshape(z_shape_og)
core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)")
output[:num_tokens], _ = self.out_proj(core_attn_out)
class Qwen3_5DecoderLayer(Qwen3NextDecoderLayer): class Qwen3_5DecoderLayer(Qwen3NextDecoderLayer):
def __init__( def __init__(
self, self,
...@@ -282,10 +133,12 @@ class Qwen3_5DecoderLayer(Qwen3NextDecoderLayer): ...@@ -282,10 +133,12 @@ class Qwen3_5DecoderLayer(Qwen3NextDecoderLayer):
self.layer_idx = extract_layer_index(prefix) self.layer_idx = extract_layer_index(prefix)
if self.layer_type == "linear_attention": if self.layer_type == "linear_attention":
self.linear_attn = Qwen3_5GatedDeltaNet( self.linear_attn = GatedDeltaNetAttention(
config=config, config=config,
vllm_config=vllm_config, vllm_config=vllm_config,
prefix=f"{prefix}.linear_attn", prefix=f"{prefix}.linear_attn",
gqa_interleaved_layout=False,
create_in_proj_qkvz=vllm_config.lora_config is None,
) )
elif self.layer_type == "full_attention": elif self.layer_type == "full_attention":
self.self_attn = Qwen3NextAttention( self.self_attn = Qwen3NextAttention(
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment