Unverified commit 7377131a, authored by Tao He, committed by GitHub
Browse files

[Qwen3] Enable dual-chunk-attention support for Qwen3 models. (#21924)


Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com>
parent 6b47ef24
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
# limitations under the License. # limitations under the License.
"""Inference-only Qwen3 model compatible with HuggingFace weights.""" """Inference-only Qwen3 model compatible with HuggingFace weights."""
from collections.abc import Iterable from collections.abc import Iterable
from typing import Optional, Union from typing import Any, Optional, Union
import torch import torch
from torch import nn from torch import nn
...@@ -47,27 +47,31 @@ from vllm.sequence import IntermediateTensors ...@@ -47,27 +47,31 @@ from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
from .qwen2 import Qwen2MLP as Qwen3MLP from .qwen2 import Qwen2MLP as Qwen3MLP
from .qwen2 import Qwen2Model from .qwen2 import Qwen2Model
from .utils import AutoWeightsLoader, PPMissingLayer, maybe_prefix from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
maybe_prefix)
logger = init_logger(__name__) logger = init_logger(__name__)
class Qwen3Attention(nn.Module): class Qwen3Attention(nn.Module):
def __init__(self, def __init__(
hidden_size: int, self,
num_heads: int, hidden_size: int,
num_kv_heads: int, num_heads: int,
max_position: int = 4096 * 32, num_kv_heads: int,
head_dim: Optional[int] = None, max_position: int = 4096 * 32,
rms_norm_eps: float = 1e-06, head_dim: Optional[int] = None,
qkv_bias: bool = False, rms_norm_eps: float = 1e-06,
rope_theta: float = 10000, qkv_bias: bool = False,
cache_config: Optional[CacheConfig] = None, rope_theta: float = 10000,
quant_config: Optional[QuantizationConfig] = None, cache_config: Optional[CacheConfig] = None,
rope_scaling: Optional[tuple] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "", rope_scaling: Optional[tuple] = None,
attn_type: str = AttentionType.DECODER) -> None: prefix: str = "",
attn_type: str = AttentionType.DECODER,
dual_chunk_attention_config: Optional[dict[str, Any]] = None,
) -> None:
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
...@@ -89,6 +93,7 @@ class Qwen3Attention(nn.Module): ...@@ -89,6 +93,7 @@ class Qwen3Attention(nn.Module):
self.kv_size = self.num_kv_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5 self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta self.rope_theta = rope_theta
self.dual_chunk_attention_config = dual_chunk_attention_config
self.qkv_proj = QKVParallelLinear( self.qkv_proj = QKVParallelLinear(
hidden_size, hidden_size,
...@@ -113,15 +118,22 @@ class Qwen3Attention(nn.Module): ...@@ -113,15 +118,22 @@ class Qwen3Attention(nn.Module):
max_position=max_position, max_position=max_position,
base=self.rope_theta, base=self.rope_theta,
rope_scaling=rope_scaling, rope_scaling=rope_scaling,
dual_chunk_attention_config=dual_chunk_attention_config,
)
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
attn_type=attn_type,
**{
"layer_idx": extract_layer_index(prefix),
"dual_chunk_attention_config": dual_chunk_attention_config,
} if dual_chunk_attention_config else {},
) )
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
attn_type=attn_type)
self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
...@@ -161,6 +173,9 @@ class Qwen3DecoderLayer(nn.Module): ...@@ -161,6 +173,9 @@ class Qwen3DecoderLayer(nn.Module):
# Requires transformers > 4.32.0 # Requires transformers > 4.32.0
rope_theta = getattr(config, "rope_theta", 1000000) rope_theta = getattr(config, "rope_theta", 1000000)
rope_scaling = getattr(config, "rope_scaling", None) rope_scaling = getattr(config, "rope_scaling", None)
dual_chunk_attention_config = getattr(config,
"dual_chunk_attention_config",
None)
# By default, Qwen3 uses causal attention as it is a decoder-only model. # By default, Qwen3 uses causal attention as it is a decoder-only model.
# You can override the HF config with `is_causal=False` to enable # You can override the HF config with `is_causal=False` to enable
...@@ -185,6 +200,7 @@ class Qwen3DecoderLayer(nn.Module): ...@@ -185,6 +200,7 @@ class Qwen3DecoderLayer(nn.Module):
rope_scaling=rope_scaling, rope_scaling=rope_scaling,
prefix=f"{prefix}.self_attn", prefix=f"{prefix}.self_attn",
attn_type=attn_type, attn_type=attn_type,
dual_chunk_attention_config=dual_chunk_attention_config,
) )
self.mlp = Qwen3MLP( self.mlp = Qwen3MLP(
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
......
...@@ -185,6 +185,7 @@ class Qwen3MoeAttention(nn.Module): ...@@ -185,6 +185,7 @@ class Qwen3MoeAttention(nn.Module):
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "", prefix: str = "",
dual_chunk_attention_config: Optional[dict[str, Any]] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
...@@ -208,6 +209,7 @@ class Qwen3MoeAttention(nn.Module): ...@@ -208,6 +209,7 @@ class Qwen3MoeAttention(nn.Module):
self.scaling = self.head_dim**-0.5 self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings self.max_position_embeddings = max_position_embeddings
self.dual_chunk_attention_config = dual_chunk_attention_config
self.qkv_proj = QKVParallelLinear(hidden_size, self.qkv_proj = QKVParallelLinear(hidden_size,
self.head_dim, self.head_dim,
...@@ -229,14 +231,21 @@ class Qwen3MoeAttention(nn.Module): ...@@ -229,14 +231,21 @@ class Qwen3MoeAttention(nn.Module):
max_position=max_position_embeddings, max_position=max_position_embeddings,
base=rope_theta, base=rope_theta,
rope_scaling=rope_scaling, rope_scaling=rope_scaling,
dual_chunk_attention_config=dual_chunk_attention_config,
)
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
**{
"layer_idx": extract_layer_index(prefix),
"dual_chunk_attention_config": dual_chunk_attention_config,
} if dual_chunk_attention_config else {},
) )
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn")
self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
...@@ -280,6 +289,9 @@ class Qwen3MoeDecoderLayer(nn.Module): ...@@ -280,6 +289,9 @@ class Qwen3MoeDecoderLayer(nn.Module):
rope_scaling = getattr(config, "rope_scaling", None) rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings", max_position_embeddings = getattr(config, "max_position_embeddings",
8192) 8192)
dual_chunk_attention_config = getattr(config,
"dual_chunk_attention_config",
None)
self.self_attn = Qwen3MoeAttention( self.self_attn = Qwen3MoeAttention(
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
num_heads=config.num_attention_heads, num_heads=config.num_attention_heads,
...@@ -293,6 +305,7 @@ class Qwen3MoeDecoderLayer(nn.Module): ...@@ -293,6 +305,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.self_attn", prefix=f"{prefix}.self_attn",
dual_chunk_attention_config=dual_chunk_attention_config,
) )
# `mlp_only_layers` in the config. # `mlp_only_layers` in the config.
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment