Commit f3505904 authored by yuanyuan's avatar yuanyuan
Browse files

porting some qwen3.5 fp8 block quant bugfix

parent f28b6574
...@@ -238,3 +238,5 @@ ep_kernels_workspace/ ...@@ -238,3 +238,5 @@ ep_kernels_workspace/
vllm/grpc/vllm_engine_pb2.py vllm/grpc/vllm_engine_pb2.py
vllm/grpc/vllm_engine_pb2_grpc.py vllm/grpc/vllm_engine_pb2_grpc.py
vllm/grpc/vllm_engine_pb2.pyi vllm/grpc/vllm_engine_pb2.pyi
vllm/version.py
\ No newline at end of file
...@@ -748,8 +748,13 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -748,8 +748,13 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
self, self,
param: Parameter, param: Parameter,
loaded_weight: torch.Tensor, loaded_weight: torch.Tensor,
loaded_shard_id: int | None = None, loaded_shard_id: tuple[int, ...] | int | None = None,
): ):
if isinstance(loaded_shard_id, tuple):
raise NotImplementedError(
"Shard id with multiple indices is not supported in weight_loader, "
"please use weight_loader_v2 instead."
)
# Special case for GGUF # Special case for GGUF
# initialize GGUF param after we know the quantize type # initialize GGUF param after we know the quantize type
is_gguf_weight = getattr(param, "is_gguf_weight", False) is_gguf_weight = getattr(param, "is_gguf_weight", False)
...@@ -781,7 +786,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -781,7 +786,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
# Special case for per-tensor scale to load scalar into fused array. # Special case for per-tensor scale to load scalar into fused array.
needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False) needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
if loaded_shard_id is None: if loaded_shard_id is None or isinstance(loaded_shard_id, tuple):
# Loaded weight is already fused on disk (mlp). # Loaded weight is already fused on disk (mlp).
# (e.g., Phi-3's gate_up_proj). # (e.g., Phi-3's gate_up_proj).
if output_dim is None: if output_dim is None:
...@@ -793,10 +798,20 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -793,10 +798,20 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
assert param_data.shape == loaded_weight.shape assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight) param_data.copy_(loaded_weight)
return return
output_sizes = (
self.output_sizes[loaded_shard_id[0] : loaded_shard_id[-1] + 1]
if loaded_shard_id is not None
else self.output_sizes
)
current_shard_offset = 0 current_shard_offset = 0
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False) use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
if use_bitsandbytes_4bit and isinstance(loaded_shard_id, tuple):
raise NotImplementedError(
"Shard id with multiple indices is not supported "
"for BNB quantization yet."
)
shard_offsets: list[tuple[int, int, int]] = [] shard_offsets: list[tuple[int, int, int]] = []
for i, output_size in enumerate(self.output_sizes): for i, output_size in enumerate(output_sizes):
shard_offsets.append((i, current_shard_offset, output_size)) shard_offsets.append((i, current_shard_offset, output_size))
current_shard_offset += output_size current_shard_offset += output_size
packed_dim = getattr(param, "packed_dim", None) packed_dim = getattr(param, "packed_dim", None)
...@@ -838,15 +853,15 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -838,15 +853,15 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
shard_offset = sum(self.output_sizes[:loaded_shard_id]) shard_offset = sum(self.output_sizes[:loaded_shard_id])
shard_size = self.output_sizes[loaded_shard_id] shard_size = self.output_sizes[loaded_shard_id]
shard_offset //= self.tp_size
shard_size //= self.tp_size
if isinstance(param, BlockQuantScaleParameter): if isinstance(param, BlockQuantScaleParameter):
weight_block_size = getattr(self, "weight_block_size", None) weight_block_size = getattr(self, "weight_block_size", None)
shard_size, shard_offset = adjust_block_scale_shard( shard_size, shard_offset = adjust_block_scale_shard(
weight_block_size, shard_size, shard_offset weight_block_size, shard_size, shard_offset
) )
shard_offset //= self.tp_size
shard_size //= self.tp_size
# Special case for quantization. # Special case for quantization.
# If quantized, we need to adjust the offset and size to account # If quantized, we need to adjust the offset and size to account
# for the packing. # for the packing.
...@@ -901,7 +916,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -901,7 +916,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
param_data.copy_(loaded_weight) param_data.copy_(loaded_weight)
def _load_fused_module_from_checkpoint( def _load_fused_module_from_checkpoint(
self, param: BasevLLMParameter, loaded_weight: torch.Tensor self,
param: BasevLLMParameter,
loaded_weight: torch.Tensor,
output_sizes: list[int] | None = None,
): ):
""" """
Handle special case for models where MLP layers are already Handle special case for models where MLP layers are already
...@@ -915,7 +933,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -915,7 +933,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
current_shard_offset = 0 current_shard_offset = 0
shard_offsets: list[tuple[int, int, int]] = [] shard_offsets: list[tuple[int, int, int]] = []
for i, output_size in enumerate(self.output_sizes): output_sizes = output_sizes or self.output_sizes
for i, output_size in enumerate(output_sizes):
shard_offsets.append((i, current_shard_offset, output_size)) shard_offsets.append((i, current_shard_offset, output_size))
current_shard_offset += output_size current_shard_offset += output_size
...@@ -940,17 +959,30 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -940,17 +959,30 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
self, self,
param: BasevLLMParameter, param: BasevLLMParameter,
loaded_weight: torch.Tensor, loaded_weight: torch.Tensor,
loaded_shard_id: int | None = None, loaded_shard_id: tuple[int, ...] | int | None = None,
): ):
if loaded_shard_id is None: if loaded_shard_id is None or isinstance(loaded_shard_id, tuple):
if isinstance(param, PerTensorScaleParameter): if isinstance(param, PerTensorScaleParameter):
param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0) param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0)
return return
elif type(param) in (RowvLLMParameter, BasevLLMParameter): elif type(param) in (RowvLLMParameter, BasevLLMParameter):
param.load_merged_column_weight(loaded_weight=loaded_weight) param.load_merged_column_weight(loaded_weight=loaded_weight)
return return
output_sizes = (
[self.output_sizes[idx] for idx in loaded_shard_id]
if loaded_shard_id
else None
)
if isinstance(param, BlockQuantScaleParameter):
weight_block_size = getattr(self, "weight_block_size", None)
output_sizes = [
adjust_block_scale_shard(weight_block_size, size, 0)[0]
for size in (output_sizes or self.output_sizes)
]
# TODO: @dsikka - move to parameter.py # TODO: @dsikka - move to parameter.py
self._load_fused_module_from_checkpoint(param, loaded_weight) self._load_fused_module_from_checkpoint(
param, loaded_weight, output_sizes=output_sizes
)
return return
assert loaded_shard_id < len(self.output_sizes) assert loaded_shard_id < len(self.output_sizes)
...@@ -958,15 +990,15 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -958,15 +990,15 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
shard_offset = sum(self.output_sizes[:loaded_shard_id]) shard_offset = sum(self.output_sizes[:loaded_shard_id])
shard_size = self.output_sizes[loaded_shard_id] shard_size = self.output_sizes[loaded_shard_id]
shard_offset //= self.tp_size
shard_size //= self.tp_size
if isinstance(param, BlockQuantScaleParameter): if isinstance(param, BlockQuantScaleParameter):
weight_block_size = getattr(self, "weight_block_size", None) weight_block_size = getattr(self, "weight_block_size", None)
shard_size, shard_offset = adjust_block_scale_shard( shard_size, shard_offset = adjust_block_scale_shard(
weight_block_size, shard_size, shard_offset weight_block_size, shard_size, shard_offset
) )
shard_offset //= self.tp_size
shard_size //= self.tp_size
param.load_merged_column_weight( param.load_merged_column_weight(
loaded_weight=loaded_weight, loaded_weight=loaded_weight,
shard_id=loaded_shard_id, shard_id=loaded_shard_id,
......
...@@ -30,44 +30,20 @@ from collections.abc import Callable, Iterable ...@@ -30,44 +30,20 @@ from collections.abc import Callable, Iterable
import torch import torch
from einops import rearrange from einops import rearrange
from torch import nn from torch import nn
from transformers.activations import ACT2FN
from transformers.models.qwen3_5.configuration_qwen3_5 import (
Qwen3_5Config,
Qwen3_5TextConfig,
)
from transformers.models.qwen3_5_moe.configuration_qwen3_5_moe import (
Qwen3_5MoeConfig,
Qwen3_5MoeTextConfig,
)
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import ( from vllm.config import (
CacheConfig,
ModelConfig,
SpeculativeConfig,
VllmConfig, VllmConfig,
get_current_vllm_config,
) )
from vllm.distributed import ( from vllm.distributed import (
divide,
get_pp_group, get_pp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import ( from vllm.model_executor.layers.layernorm import (
GemmaRMSNorm as Qwen3_5RMSNorm, GemmaRMSNorm as Qwen3_5RMSNorm,
) )
from vllm.model_executor.layers.layernorm import RMSNormGated from vllm.model_executor.layers.linear import MergedColumnParallelLinear
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer2 import (
mamba_v2_sharded_weight_loader,
)
from vllm.model_executor.layers.mamba.mamba_utils import ( from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFunc, MambaStateCopyFunc,
MambaStateCopyFuncCalculator, MambaStateCopyFuncCalculator,
...@@ -81,12 +57,18 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -81,12 +57,18 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
) )
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, default_weight_loader,
sharded_weight_loader, maybe_remap_kv_scale_name,
) )
from vllm.model_executor.utils import set_weight_attrs
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from transformers.models.qwen3_5.configuration_qwen3_5 import (
Qwen3_5Config,
Qwen3_5TextConfig,
)
from transformers.models.qwen3_5_moe.configuration_qwen3_5_moe import (
Qwen3_5MoeConfig,
Qwen3_5MoeTextConfig,
)
from .interfaces import ( from .interfaces import (
HasInnerState, HasInnerState,
...@@ -138,151 +120,29 @@ class Qwen3_5MoeProcessingInfo(Qwen3VLProcessingInfo): ...@@ -138,151 +120,29 @@ class Qwen3_5MoeProcessingInfo(Qwen3VLProcessingInfo):
class Qwen3_5GatedDeltaNet(Qwen3NextGatedDeltaNet): class Qwen3_5GatedDeltaNet(Qwen3NextGatedDeltaNet):
def __init__(
self,
config: Qwen3_5TextConfig | Qwen3_5MoeTextConfig,
model_config: ModelConfig | None = None,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
speculative_config: SpeculativeConfig | None = None,
prefix: str = "",
) -> None:
super(Qwen3NextGatedDeltaNet, self).__init__()
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.hidden_size = config.hidden_size
self.num_v_heads = config.linear_num_value_heads
self.num_k_heads = config.linear_num_key_heads
self.head_k_dim = config.linear_key_head_dim
self.head_v_dim = config.linear_value_head_dim
self.key_dim = self.head_k_dim * self.num_k_heads
self.value_dim = self.head_v_dim * self.num_v_heads
self.conv_kernel_size = config.linear_conv_kernel_dim
self.layer_idx = extract_layer_index(prefix)
self.activation = config.hidden_act
self.act = ACT2FN[config.hidden_act]
self.layer_norm_epsilon = config.rms_norm_eps
self.prefix = prefix
self.config = config
self.model_config = model_config
self.cache_config = cache_config
self.quant_config = quant_config
self.speculative_config = speculative_config
self.num_spec = (
self.speculative_config.num_speculative_tokens
if self.speculative_config
else 0
)
# QKV
self.conv_dim = self.key_dim * 2 + self.value_dim
self.conv1d = ColumnParallelLinear(
input_size=self.conv_kernel_size,
output_size=self.conv_dim,
bias=False,
prefix=f"{prefix}.conv1d",
)
self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
self.in_proj_qkv = MergedColumnParallelLinear(
input_size=self.hidden_size,
output_sizes=[self.key_dim, self.key_dim, self.value_dim],
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.in_proj_qkv",
)
self.in_proj_z = ColumnParallelLinear(
input_size=self.hidden_size,
output_size=self.value_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.in_proj_z",
)
self.in_proj_b = ColumnParallelLinear(
input_size=self.hidden_size,
output_size=self.num_v_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.in_proj_ba",
)
self.in_proj_a = ColumnParallelLinear(
input_size=self.hidden_size,
output_size=self.num_v_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.in_proj_a",
)
query_key_settings = (self.key_dim, 0, False)
value_settings = (self.value_dim, 0, False)
delattr(self.conv1d.weight, "weight_loader")
set_weight_attrs(
self.conv1d.weight,
{
"weight_loader": mamba_v2_sharded_weight_loader(
[
query_key_settings,
query_key_settings,
value_settings,
],
self.tp_size,
self.tp_rank,
)
},
)
# selective projection used to make dt, B and C input dependant
# time step projection (discretization)
# instantiate once and copy inv_dt in init_weights of PretrainedModel
self.dt_bias = nn.Parameter(
torch.ones(self.num_v_heads // self.tp_size),
)
self.A_log = nn.Parameter(
torch.empty(
divide(self.num_v_heads, self.tp_size),
)
)
set_weight_attrs(self.A_log, {"weight_loader": sharded_weight_loader(0)})
set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)})
self.norm = RMSNormGated(
self.head_v_dim,
eps=self.layer_norm_epsilon,
group_size=None,
norm_before_gate=True,
device=current_platform.current_device(),
dtype=config.dtype,
)
self.out_proj = RowParallelLinear(
self.value_dim,
self.hidden_size,
bias=False,
input_is_parallel=True,
quant_config=quant_config,
prefix=f"{prefix}.out_proj",
)
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
def fix_query_key_value_ordering( def fix_query_key_value_ordering(
self, self,
mixed_qkv, mixed_qkvz: torch.Tensor,
z, mixed_ba: torch.Tensor,
b,
a,
): ):
raise NotImplementedError( raise NotImplementedError(
"Qwen3.5 Series dont need to fix query key value ordering" "Qwen3.5 Series dont need to fix query key value ordering"
) )
def create_qkvz_proj(
self,
hidden_size: int,
key_dim: int,
value_dim: int,
quant_config: QuantizationConfig | None,
prefix: str,
) -> MergedColumnParallelLinear:
return MergedColumnParallelLinear(
input_size=hidden_size,
output_sizes=[key_dim, key_dim, value_dim, value_dim],
bias=False,
quant_config=quant_config,
prefix=prefix,
)
def forward( def forward(
self, self,
...@@ -300,11 +160,13 @@ class Qwen3_5GatedDeltaNet(Qwen3NextGatedDeltaNet): ...@@ -300,11 +160,13 @@ class Qwen3_5GatedDeltaNet(Qwen3NextGatedDeltaNet):
# ============================================================ # ============================================================
# Part 1: Input Projection # Part 1: Input Projection
# ============================================================ # ============================================================
mixed_qkv, _ = self.in_proj_qkv(hidden_states) mixed_qkvz, _ = self.in_proj_qkvz(hidden_states)
z, _ = self.in_proj_z(hidden_states) qkv_size = (self.key_dim * 2 + self.value_dim) // self.tp_size
z_size = self.value_dim // self.tp_size
mixed_qkv, z = mixed_qkvz.split([qkv_size, z_size], dim=-1)
z = z.reshape(z.size(0), -1, self.head_v_dim) z = z.reshape(z.size(0), -1, self.head_v_dim)
b, _ = self.in_proj_b(hidden_states) ba, _ = self.in_proj_ba(hidden_states)
a, _ = self.in_proj_a(hidden_states) b, a = ba.chunk(2, dim=-1)
b = b.contiguous() b = b.contiguous()
a = a.contiguous() a = a.contiguous()
...@@ -411,7 +273,6 @@ class Qwen3_5DecoderLayer(Qwen3NextDecoderLayer): ...@@ -411,7 +273,6 @@ class Qwen3_5DecoderLayer(Qwen3NextDecoderLayer):
1, 1,
1, 1,
config.hidden_size, config.hidden_size,
dtype=config.dtype,
), ),
) )
self.ffn_layer_scale = torch.nn.Parameter( self.ffn_layer_scale = torch.nn.Parameter(
...@@ -419,7 +280,6 @@ class Qwen3_5DecoderLayer(Qwen3NextDecoderLayer): ...@@ -419,7 +280,6 @@ class Qwen3_5DecoderLayer(Qwen3NextDecoderLayer):
1, 1,
1, 1,
config.hidden_size, config.hidden_size,
dtype=config.dtype,
), ),
) )
...@@ -503,11 +363,18 @@ class Qwen3_5Model(Qwen3NextModel): ...@@ -503,11 +363,18 @@ class Qwen3_5Model(Qwen3NextModel):
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
# self attention
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"), ("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"), ("qkv_proj", "v_proj", "v"),
# mlp
("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1), ("gate_up_proj", "up_proj", 1),
# GDN
("in_proj_qkvz", "in_proj_qkv", (0, 1, 2)),
("in_proj_qkvz", "in_proj_z", 3),
("in_proj_ba", "in_proj_b", 0),
("in_proj_ba", "in_proj_a", 1),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
...@@ -528,6 +395,12 @@ class Qwen3_5Model(Qwen3NextModel): ...@@ -528,6 +395,12 @@ class Qwen3_5Model(Qwen3NextModel):
if name.startswith("mtp."): if name.startswith("mtp."):
continue continue
# Remapping the name of FP8 kv-scale.
if name.endswith("scale"):
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
for param_name, weight_name, shard_id in stacked_params_mapping: for param_name, weight_name, shard_id in stacked_params_mapping:
if "experts.gate_up_proj" in name or "experts.down_proj" in name: if "experts.gate_up_proj" in name or "experts.down_proj" in name:
is_fused_expert = True is_fused_expert = True
...@@ -654,6 +527,9 @@ class Qwen3_5ForCausalLMBase( ...@@ -654,6 +527,9 @@ class Qwen3_5ForCausalLMBase(
"v_proj", "v_proj",
], ],
"gate_up_proj": ["gate_proj", "up_proj"], "gate_up_proj": ["gate_proj", "up_proj"],
# GDN fused projections.
"in_proj_qkvz": ["in_proj_qkv", "in_proj_z"],
"in_proj_ba": ["in_proj_b", "in_proj_a"],
} }
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
...@@ -751,7 +627,15 @@ class Qwen3_5MoeForCausalLM(Qwen3_5ForCausalLMBase, QwenNextMixtureOfExperts): ...@@ -751,7 +627,15 @@ class Qwen3_5MoeForCausalLM(Qwen3_5ForCausalLMBase, QwenNextMixtureOfExperts):
dummy_inputs=Qwen3VLDummyInputsBuilder, dummy_inputs=Qwen3VLDummyInputsBuilder,
) )
class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid): class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # Qwen3.5 does not support multimodal pruning (EVS).
supports_multimodal_pruning = False
packed_modules_mapping = Qwen3VLForConditionalGeneration.packed_modules_mapping | {
"in_proj_qkvz": ["in_proj_qkv", "in_proj_z"],
"in_proj_ba": ["in_proj_b", "in_proj_a"],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"):
# protocols have not __init__ method, so we need to use nn.Module.__init__ # protocols have not __init__ method, so we need to use nn.Module.__init__
nn.Module.__init__(self) nn.Module.__init__(self)
config: Qwen3_5Config = vllm_config.model_config.hf_config config: Qwen3_5Config = vllm_config.model_config.hf_config
...@@ -868,7 +752,9 @@ class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid) ...@@ -868,7 +752,9 @@ class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid)
vllm_config: "VllmConfig", vllm_config: "VllmConfig",
) -> tuple[torch.dtype, torch.dtype]: ) -> tuple[torch.dtype, torch.dtype]:
return MambaStateDtypeCalculator.gated_delta_net_state_dtype( return MambaStateDtypeCalculator.gated_delta_net_state_dtype(
vllm_config.model_config.dtype, vllm_config.cache_config.mamba_cache_dtype vllm_config.model_config.dtype,
vllm_config.cache_config.mamba_cache_dtype,
vllm_config.cache_config.mamba_ssm_cache_dtype,
) )
@classmethod @classmethod
...@@ -957,7 +843,7 @@ class Qwen3_5_MoeMixtureOfExperts(MixtureOfExperts): ...@@ -957,7 +843,7 @@ class Qwen3_5_MoeMixtureOfExperts(MixtureOfExperts):
class Qwen3_5MoeForConditionalGeneration( class Qwen3_5MoeForConditionalGeneration(
Qwen3_5ForConditionalGeneration, Qwen3_5_MoeMixtureOfExperts Qwen3_5ForConditionalGeneration, Qwen3_5_MoeMixtureOfExperts
): ):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"):
# protocols have not __init__ method, so we need to use nn.Module.__init__ # protocols have not __init__ method, so we need to use nn.Module.__init__
nn.Module.__init__(self) nn.Module.__init__(self)
config: Qwen3_5MoeConfig = vllm_config.model_config.hf_config config: Qwen3_5MoeConfig = vllm_config.model_config.hf_config
......
...@@ -7,10 +7,6 @@ from collections.abc import Callable, Iterable ...@@ -7,10 +7,6 @@ from collections.abc import Callable, Iterable
import torch import torch
from torch import nn from torch import nn
from transformers.models.qwen3_5.configuration_qwen3_5 import Qwen3_5TextConfig
from transformers.models.qwen3_5_moe.configuration_qwen3_5_moe import (
Qwen3_5MoeTextConfig,
)
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig from vllm.config import VllmConfig
...@@ -27,6 +23,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -27,6 +23,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.qwen3_5 import Qwen3_5DecoderLayer, Qwen3_5RMSNorm from vllm.model_executor.models.qwen3_5 import Qwen3_5DecoderLayer, Qwen3_5RMSNorm
from vllm.model_executor.models.qwen3_next import QwenNextMixtureOfExperts from vllm.model_executor.models.qwen3_next import QwenNextMixtureOfExperts
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.qwen3_5 import Qwen3_5TextConfig
from vllm.transformers_utils.configs.qwen3_5_moe import Qwen3_5MoeTextConfig
from .interfaces import ( from .interfaces import (
MultiModalEmbeddings, MultiModalEmbeddings,
......
...@@ -40,6 +40,7 @@ from vllm.model_executor.layers.layernorm import ( ...@@ -40,6 +40,7 @@ from vllm.model_executor.layers.layernorm import (
from vllm.model_executor.layers.layernorm import RMSNormGated from vllm.model_executor.layers.layernorm import RMSNormGated
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
ColumnParallelLinear, ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
ReplicatedLinear, ReplicatedLinear,
RowParallelLinear, RowParallelLinear,
...@@ -292,19 +293,19 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): ...@@ -292,19 +293,19 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
# projection of the input hidden states # projection of the input hidden states
self.projection_size_qkvz = self.key_dim * 2 + self.value_dim * 2 # Qwen3-Next and Qwen3.5 has a different qkv_proj layout,
self.projection_size_ba = self.num_v_heads * 2 # we need to create qkvz_proj adaptively here.
self.in_proj_qkvz = ColumnParallelLinear( self.in_proj_qkvz = self.create_qkvz_proj(
input_size=self.hidden_size, hidden_size=self.hidden_size,
output_size=self.projection_size_qkvz, key_dim=self.key_dim,
bias=False, value_dim=self.value_dim,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.in_proj_qkvz", prefix=f"{prefix}.in_proj_qkvz",
) )
# ba_proj doesn't support blockwise fp8 quantization. # ba_proj doesn't support blockwise fp8 quantization.
self.in_proj_ba = ColumnParallelLinear( self.in_proj_ba = MergedColumnParallelLinear(
input_size=self.hidden_size, input_size=self.hidden_size,
output_size=self.projection_size_ba, output_sizes=[self.num_v_heads] * 2,
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.in_proj_ba", prefix=f"{prefix}.in_proj_ba",
...@@ -351,7 +352,6 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): ...@@ -351,7 +352,6 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
group_size=None, group_size=None,
norm_before_gate=True, norm_before_gate=True,
device=current_platform.current_device(), device=current_platform.current_device(),
dtype=config.dtype,
) )
self.out_proj = RowParallelLinear( self.out_proj = RowParallelLinear(
...@@ -368,10 +368,26 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): ...@@ -368,10 +368,26 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
raise ValueError(f"Duplicate layer name: {prefix}") raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self compilation_config.static_forward_context[prefix] = self
def create_qkvz_proj(
self,
hidden_size: int,
key_dim: int,
value_dim: int,
quant_config: QuantizationConfig | None,
prefix: str,
) -> MergedColumnParallelLinear:
return MergedColumnParallelLinear(
input_size=hidden_size,
output_sizes=[sum((key_dim, key_dim, value_dim)), value_dim],
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.in_proj_qkvz",
)
def fix_query_key_value_ordering( def fix_query_key_value_ordering(
self, self,
mixed_qkvz, mixed_qkvz: torch.Tensor,
mixed_ba, mixed_ba: torch.Tensor,
): ):
""" """
Derives `query`, `key` and `value` tensors from `mixed_qkvzba`. Derives `query`, `key` and `value` tensors from `mixed_qkvzba`.
...@@ -893,7 +909,6 @@ class Qwen3NextDecoderLayer(nn.Module): ...@@ -893,7 +909,6 @@ class Qwen3NextDecoderLayer(nn.Module):
1, 1,
1, 1,
config.hidden_size, config.hidden_size,
dtype=config.dtype,
), ),
) )
self.ffn_layer_scale = torch.nn.Parameter( self.ffn_layer_scale = torch.nn.Parameter(
...@@ -901,7 +916,6 @@ class Qwen3NextDecoderLayer(nn.Module): ...@@ -901,7 +916,6 @@ class Qwen3NextDecoderLayer(nn.Module):
1, 1,
1, 1,
config.hidden_size, config.hidden_size,
dtype=config.dtype,
), ),
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment