Unverified commit 1d65283e, authored by Jiangyun Zhu, committed by GitHub
Browse files

Revert "[Models] Fuse Qwen3.5 GDN's qkvz_proj and ba_proj" (#34683)

parent c464b573
...@@ -685,13 +685,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -685,13 +685,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
self, self,
param: Parameter, param: Parameter,
loaded_weight: torch.Tensor, loaded_weight: torch.Tensor,
loaded_shard_id: tuple[int, ...] | int | None = None, loaded_shard_id: int | None = None,
): ):
if isinstance(loaded_shard_id, tuple):
raise NotImplementedError(
"Shard id with multiple indices is not supported in weight_loader, "
"please use weight_loader_v2 instead."
)
# Special case for GGUF # Special case for GGUF
# initialize GGUF param after we know the quantize type # initialize GGUF param after we know the quantize type
is_gguf_weight = getattr(param, "is_gguf_weight", False) is_gguf_weight = getattr(param, "is_gguf_weight", False)
...@@ -830,10 +825,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -830,10 +825,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
param_data.copy_(loaded_weight) param_data.copy_(loaded_weight)
def _load_fused_module_from_checkpoint( def _load_fused_module_from_checkpoint(
self, self, param: BasevLLMParameter, loaded_weight: torch.Tensor
param: BasevLLMParameter,
loaded_weight: torch.Tensor,
output_sizes: list[int] | None = None,
): ):
""" """
Handle special case for models where MLP layers are already Handle special case for models where MLP layers are already
...@@ -847,8 +839,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -847,8 +839,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
current_shard_offset = 0 current_shard_offset = 0
shard_offsets: list[tuple[int, int, int]] = [] shard_offsets: list[tuple[int, int, int]] = []
output_sizes = output_sizes or self.output_sizes for i, output_size in enumerate(self.output_sizes):
for i, output_size in enumerate(output_sizes):
shard_offsets.append((i, current_shard_offset, output_size)) shard_offsets.append((i, current_shard_offset, output_size))
current_shard_offset += output_size current_shard_offset += output_size
...@@ -873,30 +864,17 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -873,30 +864,17 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
self, self,
param: BasevLLMParameter, param: BasevLLMParameter,
loaded_weight: torch.Tensor, loaded_weight: torch.Tensor,
loaded_shard_id: tuple[int, ...] | int | None = None, loaded_shard_id: int | None = None,
): ):
if loaded_shard_id is None or isinstance(loaded_shard_id, tuple): if loaded_shard_id is None:
if isinstance(param, PerTensorScaleParameter): if isinstance(param, PerTensorScaleParameter):
param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0) param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0)
return return
elif type(param) in (RowvLLMParameter, BasevLLMParameter): elif type(param) in (RowvLLMParameter, BasevLLMParameter):
param.load_merged_column_weight(loaded_weight=loaded_weight) param.load_merged_column_weight(loaded_weight=loaded_weight)
return return
output_sizes = (
[self.output_sizes[idx] for idx in loaded_shard_id]
if loaded_shard_id
else None
)
if isinstance(param, BlockQuantScaleParameter):
weight_block_size = getattr(self, "weight_block_size", None)
output_sizes = [
adjust_block_scale_shard(weight_block_size, size, 0)[0]
for size in (output_sizes or self.output_sizes)
]
# TODO: @dsikka - move to parameter.py # TODO: @dsikka - move to parameter.py
self._load_fused_module_from_checkpoint( self._load_fused_module_from_checkpoint(param, loaded_weight)
param, loaded_weight, output_sizes=output_sizes
)
return return
assert loaded_shard_id < len(self.output_sizes) assert loaded_shard_id < len(self.output_sizes)
......
...@@ -30,20 +30,36 @@ from collections.abc import Callable, Iterable ...@@ -30,20 +30,36 @@ from collections.abc import Callable, Iterable
import torch import torch
from einops import rearrange from einops import rearrange
from torch import nn from torch import nn
from transformers.activations import ACT2FN
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import ( from vllm.config import (
CacheConfig,
ModelConfig,
SpeculativeConfig,
VllmConfig, VllmConfig,
get_current_vllm_config,
) )
from vllm.distributed import ( from vllm.distributed import (
divide,
get_pp_group, get_pp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import ( from vllm.model_executor.layers.layernorm import (
GemmaRMSNorm as Qwen3_5RMSNorm, GemmaRMSNorm as Qwen3_5RMSNorm,
) )
from vllm.model_executor.layers.linear import MergedColumnParallelLinear from vllm.model_executor.layers.layernorm import RMSNormGated
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer2 import (
mamba_v2_sharded_weight_loader,
)
from vllm.model_executor.layers.mamba.mamba_utils import ( from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFunc, MambaStateCopyFunc,
MambaStateCopyFuncCalculator, MambaStateCopyFuncCalculator,
...@@ -57,8 +73,11 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -57,8 +73,11 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
) )
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, default_weight_loader,
sharded_weight_loader,
) )
from vllm.model_executor.utils import set_weight_attrs
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.qwen3_5 import ( from vllm.transformers_utils.configs.qwen3_5 import (
Qwen3_5Config, Qwen3_5Config,
...@@ -80,6 +99,7 @@ from .interfaces import ( ...@@ -80,6 +99,7 @@ from .interfaces import (
) )
from .qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP from .qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP
from .qwen3_next import ( from .qwen3_next import (
ChunkGatedDeltaRule,
Qwen3NextAttention, Qwen3NextAttention,
Qwen3NextDecoderLayer, Qwen3NextDecoderLayer,
Qwen3NextGatedDeltaNet, Qwen3NextGatedDeltaNet,
...@@ -119,29 +139,152 @@ class Qwen3_5MoeProcessingInfo(Qwen3VLProcessingInfo): ...@@ -119,29 +139,152 @@ class Qwen3_5MoeProcessingInfo(Qwen3VLProcessingInfo):
class Qwen3_5GatedDeltaNet(Qwen3NextGatedDeltaNet): class Qwen3_5GatedDeltaNet(Qwen3NextGatedDeltaNet):
def fix_query_key_value_ordering( def __init__(
self, self,
mixed_qkvz: torch.Tensor, config: Qwen3_5TextConfig | Qwen3_5MoeTextConfig,
mixed_ba: torch.Tensor, model_config: ModelConfig | None = None,
): cache_config: CacheConfig | None = None,
raise NotImplementedError( quant_config: QuantizationConfig | None = None,
"Qwen3.5 Series dont need to fix query key value ordering" speculative_config: SpeculativeConfig | None = None,
prefix: str = "",
) -> None:
super(Qwen3NextGatedDeltaNet, self).__init__()
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.hidden_size = config.hidden_size
self.num_v_heads = config.linear_num_value_heads
self.num_k_heads = config.linear_num_key_heads
self.head_k_dim = config.linear_key_head_dim
self.head_v_dim = config.linear_value_head_dim
self.key_dim = self.head_k_dim * self.num_k_heads
self.value_dim = self.head_v_dim * self.num_v_heads
self.conv_kernel_size = config.linear_conv_kernel_dim
self.layer_idx = extract_layer_index(prefix)
self.activation = config.hidden_act
self.act = ACT2FN[config.hidden_act]
self.layer_norm_epsilon = config.rms_norm_eps
self.prefix = prefix
self.config = config
self.model_config = model_config
self.cache_config = cache_config
self.quant_config = quant_config
self.speculative_config = speculative_config
self.num_spec = (
self.speculative_config.num_speculative_tokens
if self.speculative_config
else 0
) )
def create_qkvz_proj( # QKV
self, self.conv_dim = self.key_dim * 2 + self.value_dim
hidden_size: int, self.conv1d = ColumnParallelLinear(
key_dim: int, input_size=self.conv_kernel_size,
value_dim: int, output_size=self.conv_dim,
quant_config: QuantizationConfig | None, bias=False,
prefix: str, prefix=f"{prefix}.conv1d",
) -> MergedColumnParallelLinear: )
return MergedColumnParallelLinear( self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
input_size=hidden_size,
output_sizes=[key_dim, key_dim, value_dim, value_dim], self.in_proj_qkv = MergedColumnParallelLinear(
input_size=self.hidden_size,
output_sizes=[self.key_dim, self.key_dim, self.value_dim],
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=prefix, prefix=f"{prefix}.in_proj_qkv",
)
self.in_proj_z = ColumnParallelLinear(
input_size=self.hidden_size,
output_size=self.value_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.in_proj_z",
)
self.in_proj_b = ColumnParallelLinear(
input_size=self.hidden_size,
output_size=self.num_v_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.in_proj_b",
)
self.in_proj_a = ColumnParallelLinear(
input_size=self.hidden_size,
output_size=self.num_v_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.in_proj_a",
)
query_key_settings = (self.key_dim, 0, False)
value_settings = (self.value_dim, 0, False)
delattr(self.conv1d.weight, "weight_loader")
set_weight_attrs(
self.conv1d.weight,
{
"weight_loader": mamba_v2_sharded_weight_loader(
[
query_key_settings,
query_key_settings,
value_settings,
],
self.tp_size,
self.tp_rank,
)
},
)
# selective projection used to make dt, B and C input dependant
# time step projection (discretization)
# instantiate once and copy inv_dt in init_weights of PretrainedModel
self.dt_bias = nn.Parameter(
torch.ones(self.num_v_heads // self.tp_size),
)
self.A_log = nn.Parameter(
torch.empty(
divide(self.num_v_heads, self.tp_size),
)
)
set_weight_attrs(self.A_log, {"weight_loader": sharded_weight_loader(0)})
set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)})
self.norm = RMSNormGated(
self.head_v_dim,
eps=self.layer_norm_epsilon,
group_size=None,
norm_before_gate=True,
device=current_platform.current_device(),
dtype=config.dtype,
)
self.out_proj = RowParallelLinear(
self.value_dim,
self.hidden_size,
bias=False,
input_is_parallel=True,
quant_config=quant_config,
prefix=f"{prefix}.out_proj",
)
self.chunk_gated_delta_rule = ChunkGatedDeltaRule()
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
def fix_query_key_value_ordering(
self,
mixed_qkv,
z,
b,
a,
):
raise NotImplementedError(
"Qwen3.5 Series dont need to fix query key value ordering"
) )
def forward( def forward(
...@@ -160,13 +303,11 @@ class Qwen3_5GatedDeltaNet(Qwen3NextGatedDeltaNet): ...@@ -160,13 +303,11 @@ class Qwen3_5GatedDeltaNet(Qwen3NextGatedDeltaNet):
# ============================================================ # ============================================================
# Part 1: Input Projection # Part 1: Input Projection
# ============================================================ # ============================================================
mixed_qkvz, _ = self.in_proj_qkvz(hidden_states) mixed_qkv, _ = self.in_proj_qkv(hidden_states)
qkv_size = (self.key_dim * 2 + self.value_dim) // self.tp_size z, _ = self.in_proj_z(hidden_states)
z_size = self.value_dim // self.tp_size
mixed_qkv, z = mixed_qkvz.split([qkv_size, z_size], dim=-1)
z = z.reshape(z.size(0), -1, self.head_v_dim) z = z.reshape(z.size(0), -1, self.head_v_dim)
ba, _ = self.in_proj_ba(hidden_states) b, _ = self.in_proj_b(hidden_states)
b, a = ba.chunk(2, dim=-1) a, _ = self.in_proj_a(hidden_states)
b = b.contiguous() b = b.contiguous()
a = a.contiguous() a = a.contiguous()
...@@ -365,18 +506,11 @@ class Qwen3_5Model(Qwen3NextModel): ...@@ -365,18 +506,11 @@ class Qwen3_5Model(Qwen3NextModel):
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
# self attention
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"), ("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"), ("qkv_proj", "v_proj", "v"),
# mlp
("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1), ("gate_up_proj", "up_proj", 1),
# GDN
("in_proj_qkvz", "in_proj_qkv", (0, 1, 2)),
("in_proj_qkvz", "in_proj_z", 3),
("in_proj_ba", "in_proj_b", 0),
("in_proj_ba", "in_proj_a", 1),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
......
...@@ -44,7 +44,6 @@ from vllm.model_executor.layers.layernorm import ( ...@@ -44,7 +44,6 @@ from vllm.model_executor.layers.layernorm import (
from vllm.model_executor.layers.layernorm import RMSNormGated from vllm.model_executor.layers.layernorm import RMSNormGated
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
ColumnParallelLinear, ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
ReplicatedLinear, ReplicatedLinear,
RowParallelLinear, RowParallelLinear,
...@@ -407,19 +406,19 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): ...@@ -407,19 +406,19 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
# projection of the input hidden states # projection of the input hidden states
# Qwen3-Next and Qwen3.5 has a different qkv_proj layout, self.projection_size_qkvz = self.key_dim * 2 + self.value_dim * 2
# we need to create qkvz_proj adaptively here. self.projection_size_ba = self.num_v_heads * 2
self.in_proj_qkvz = self.create_qkvz_proj( self.in_proj_qkvz = ColumnParallelLinear(
hidden_size=self.hidden_size, input_size=self.hidden_size,
key_dim=self.key_dim, output_size=self.projection_size_qkvz,
value_dim=self.value_dim, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.in_proj_qkvz", prefix=f"{prefix}.in_proj_qkvz",
) )
# ba_proj doesn't support blockwise fp8 quantization. # ba_proj doesn't support blockwise fp8 quantization.
self.in_proj_ba = MergedColumnParallelLinear( self.in_proj_ba = ColumnParallelLinear(
input_size=self.hidden_size, input_size=self.hidden_size,
output_sizes=[self.num_v_heads] * 2, output_size=self.projection_size_ba,
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.in_proj_ba", prefix=f"{prefix}.in_proj_ba",
...@@ -485,26 +484,10 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): ...@@ -485,26 +484,10 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
raise ValueError(f"Duplicate layer name: {prefix}") raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self compilation_config.static_forward_context[prefix] = self
def create_qkvz_proj(
self,
hidden_size: int,
key_dim: int,
value_dim: int,
quant_config: QuantizationConfig | None,
prefix: str,
) -> MergedColumnParallelLinear:
return MergedColumnParallelLinear(
input_size=hidden_size,
output_sizes=[sum((key_dim, key_dim, value_dim)), value_dim],
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.in_proj_qkvz",
)
def fix_query_key_value_ordering( def fix_query_key_value_ordering(
self, self,
mixed_qkvz: torch.Tensor, mixed_qkvz,
mixed_ba: torch.Tensor, mixed_ba,
): ):
""" """
Derives `query`, `key` and `value` tensors from `mixed_qkvzba`. Derives `query`, `key` and `value` tensors from `mixed_qkvzba`.
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment