Commit f3505904 authored by yuanyuan's avatar yuanyuan
Browse files

porting some qwen3.5 fp8 block quant bugfix

parent f28b6574
......@@ -238,3 +238,5 @@ ep_kernels_workspace/
vllm/grpc/vllm_engine_pb2.py
vllm/grpc/vllm_engine_pb2_grpc.py
vllm/grpc/vllm_engine_pb2.pyi
vllm/version.py
\ No newline at end of file
......@@ -748,8 +748,13 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
self,
param: Parameter,
loaded_weight: torch.Tensor,
loaded_shard_id: int | None = None,
loaded_shard_id: tuple[int, ...] | int | None = None,
):
if isinstance(loaded_shard_id, tuple):
raise NotImplementedError(
"Shard id with multiple indices is not supported in weight_loader, "
"please use weight_loader_v2 instead."
)
# Special case for GGUF
# initialize GGUF param after we know the quantize type
is_gguf_weight = getattr(param, "is_gguf_weight", False)
......@@ -781,7 +786,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
# Special case for per-tensor scale to load scalar into fused array.
needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
if loaded_shard_id is None:
if loaded_shard_id is None or isinstance(loaded_shard_id, tuple):
# Loaded weight is already fused on disk (mlp).
# (e.g., Phi-3's gate_up_proj).
if output_dim is None:
......@@ -793,10 +798,20 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
return
output_sizes = (
self.output_sizes[loaded_shard_id[0] : loaded_shard_id[-1] + 1]
if loaded_shard_id is not None
else self.output_sizes
)
current_shard_offset = 0
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
if use_bitsandbytes_4bit and isinstance(loaded_shard_id, tuple):
raise NotImplementedError(
"Shard id with multiple indices is not supported "
"for BNB quantization yet."
)
shard_offsets: list[tuple[int, int, int]] = []
for i, output_size in enumerate(self.output_sizes):
for i, output_size in enumerate(output_sizes):
shard_offsets.append((i, current_shard_offset, output_size))
current_shard_offset += output_size
packed_dim = getattr(param, "packed_dim", None)
......@@ -838,15 +853,15 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
shard_offset = sum(self.output_sizes[:loaded_shard_id])
shard_size = self.output_sizes[loaded_shard_id]
shard_offset //= self.tp_size
shard_size //= self.tp_size
if isinstance(param, BlockQuantScaleParameter):
weight_block_size = getattr(self, "weight_block_size", None)
shard_size, shard_offset = adjust_block_scale_shard(
weight_block_size, shard_size, shard_offset
)
shard_offset //= self.tp_size
shard_size //= self.tp_size
# Special case for quantization.
# If quantized, we need to adjust the offset and size to account
# for the packing.
......@@ -901,7 +916,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
param_data.copy_(loaded_weight)
def _load_fused_module_from_checkpoint(
self, param: BasevLLMParameter, loaded_weight: torch.Tensor
self,
param: BasevLLMParameter,
loaded_weight: torch.Tensor,
output_sizes: list[int] | None = None,
):
"""
Handle special case for models where MLP layers are already
......@@ -915,7 +933,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
current_shard_offset = 0
shard_offsets: list[tuple[int, int, int]] = []
for i, output_size in enumerate(self.output_sizes):
output_sizes = output_sizes or self.output_sizes
for i, output_size in enumerate(output_sizes):
shard_offsets.append((i, current_shard_offset, output_size))
current_shard_offset += output_size
......@@ -940,17 +959,30 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
self,
param: BasevLLMParameter,
loaded_weight: torch.Tensor,
loaded_shard_id: int | None = None,
loaded_shard_id: tuple[int, ...] | int | None = None,
):
if loaded_shard_id is None:
if loaded_shard_id is None or isinstance(loaded_shard_id, tuple):
if isinstance(param, PerTensorScaleParameter):
param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0)
return
elif type(param) in (RowvLLMParameter, BasevLLMParameter):
param.load_merged_column_weight(loaded_weight=loaded_weight)
return
output_sizes = (
[self.output_sizes[idx] for idx in loaded_shard_id]
if loaded_shard_id
else None
)
if isinstance(param, BlockQuantScaleParameter):
weight_block_size = getattr(self, "weight_block_size", None)
output_sizes = [
adjust_block_scale_shard(weight_block_size, size, 0)[0]
for size in (output_sizes or self.output_sizes)
]
# TODO: @dsikka - move to parameter.py
self._load_fused_module_from_checkpoint(param, loaded_weight)
self._load_fused_module_from_checkpoint(
param, loaded_weight, output_sizes=output_sizes
)
return
assert loaded_shard_id < len(self.output_sizes)
......@@ -958,15 +990,15 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
shard_offset = sum(self.output_sizes[:loaded_shard_id])
shard_size = self.output_sizes[loaded_shard_id]
shard_offset //= self.tp_size
shard_size //= self.tp_size
if isinstance(param, BlockQuantScaleParameter):
weight_block_size = getattr(self, "weight_block_size", None)
shard_size, shard_offset = adjust_block_scale_shard(
weight_block_size, shard_size, shard_offset
)
shard_offset //= self.tp_size
shard_size //= self.tp_size
param.load_merged_column_weight(
loaded_weight=loaded_weight,
shard_id=loaded_shard_id,
......
......@@ -30,44 +30,20 @@ from collections.abc import Callable, Iterable
import torch
from einops import rearrange
from torch import nn
from transformers.activations import ACT2FN
from transformers.models.qwen3_5.configuration_qwen3_5 import (
Qwen3_5Config,
Qwen3_5TextConfig,
)
from transformers.models.qwen3_5_moe.configuration_qwen3_5_moe import (
Qwen3_5MoeConfig,
Qwen3_5MoeTextConfig,
)
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (
CacheConfig,
ModelConfig,
SpeculativeConfig,
VllmConfig,
get_current_vllm_config,
)
from vllm.distributed import (
divide,
get_pp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import (
GemmaRMSNorm as Qwen3_5RMSNorm,
)
from vllm.model_executor.layers.layernorm import RMSNormGated
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.linear import MergedColumnParallelLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer2 import (
mamba_v2_sharded_weight_loader,
)
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFunc,
MambaStateCopyFuncCalculator,
......@@ -81,12 +57,18 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
sharded_weight_loader,
maybe_remap_kv_scale_name,
)
from vllm.model_executor.utils import set_weight_attrs
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from transformers.models.qwen3_5.configuration_qwen3_5 import (
Qwen3_5Config,
Qwen3_5TextConfig,
)
from transformers.models.qwen3_5_moe.configuration_qwen3_5_moe import (
Qwen3_5MoeConfig,
Qwen3_5MoeTextConfig,
)
from .interfaces import (
HasInnerState,
......@@ -138,151 +120,29 @@ class Qwen3_5MoeProcessingInfo(Qwen3VLProcessingInfo):
class Qwen3_5GatedDeltaNet(Qwen3NextGatedDeltaNet):
def __init__(
self,
config: Qwen3_5TextConfig | Qwen3_5MoeTextConfig,
model_config: ModelConfig | None = None,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
speculative_config: SpeculativeConfig | None = None,
prefix: str = "",
) -> None:
super(Qwen3NextGatedDeltaNet, self).__init__()
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.hidden_size = config.hidden_size
self.num_v_heads = config.linear_num_value_heads
self.num_k_heads = config.linear_num_key_heads
self.head_k_dim = config.linear_key_head_dim
self.head_v_dim = config.linear_value_head_dim
self.key_dim = self.head_k_dim * self.num_k_heads
self.value_dim = self.head_v_dim * self.num_v_heads
self.conv_kernel_size = config.linear_conv_kernel_dim
self.layer_idx = extract_layer_index(prefix)
self.activation = config.hidden_act
self.act = ACT2FN[config.hidden_act]
self.layer_norm_epsilon = config.rms_norm_eps
self.prefix = prefix
self.config = config
self.model_config = model_config
self.cache_config = cache_config
self.quant_config = quant_config
self.speculative_config = speculative_config
self.num_spec = (
self.speculative_config.num_speculative_tokens
if self.speculative_config
else 0
)
# QKV
self.conv_dim = self.key_dim * 2 + self.value_dim
self.conv1d = ColumnParallelLinear(
input_size=self.conv_kernel_size,
output_size=self.conv_dim,
bias=False,
prefix=f"{prefix}.conv1d",
)
self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
self.in_proj_qkv = MergedColumnParallelLinear(
input_size=self.hidden_size,
output_sizes=[self.key_dim, self.key_dim, self.value_dim],
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.in_proj_qkv",
)
self.in_proj_z = ColumnParallelLinear(
input_size=self.hidden_size,
output_size=self.value_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.in_proj_z",
)
self.in_proj_b = ColumnParallelLinear(
input_size=self.hidden_size,
output_size=self.num_v_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.in_proj_ba",
)
self.in_proj_a = ColumnParallelLinear(
input_size=self.hidden_size,
output_size=self.num_v_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.in_proj_a",
)
query_key_settings = (self.key_dim, 0, False)
value_settings = (self.value_dim, 0, False)
delattr(self.conv1d.weight, "weight_loader")
set_weight_attrs(
self.conv1d.weight,
{
"weight_loader": mamba_v2_sharded_weight_loader(
[
query_key_settings,
query_key_settings,
value_settings,
],
self.tp_size,
self.tp_rank,
)
},
)
# selective projection used to make dt, B and C input dependant
# time step projection (discretization)
# instantiate once and copy inv_dt in init_weights of PretrainedModel
self.dt_bias = nn.Parameter(
torch.ones(self.num_v_heads // self.tp_size),
)
self.A_log = nn.Parameter(
torch.empty(
divide(self.num_v_heads, self.tp_size),
)
)
set_weight_attrs(self.A_log, {"weight_loader": sharded_weight_loader(0)})
set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)})
self.norm = RMSNormGated(
self.head_v_dim,
eps=self.layer_norm_epsilon,
group_size=None,
norm_before_gate=True,
device=current_platform.current_device(),
dtype=config.dtype,
)
self.out_proj = RowParallelLinear(
self.value_dim,
self.hidden_size,
bias=False,
input_is_parallel=True,
quant_config=quant_config,
prefix=f"{prefix}.out_proj",
)
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
def fix_query_key_value_ordering(
self,
mixed_qkv,
z,
b,
a,
mixed_qkvz: torch.Tensor,
mixed_ba: torch.Tensor,
):
raise NotImplementedError(
"Qwen3.5 Series dont need to fix query key value ordering"
)
def create_qkvz_proj(
self,
hidden_size: int,
key_dim: int,
value_dim: int,
quant_config: QuantizationConfig | None,
prefix: str,
) -> MergedColumnParallelLinear:
return MergedColumnParallelLinear(
input_size=hidden_size,
output_sizes=[key_dim, key_dim, value_dim, value_dim],
bias=False,
quant_config=quant_config,
prefix=prefix,
)
def forward(
self,
......@@ -300,11 +160,13 @@ class Qwen3_5GatedDeltaNet(Qwen3NextGatedDeltaNet):
# ============================================================
# Part 1: Input Projection
# ============================================================
mixed_qkv, _ = self.in_proj_qkv(hidden_states)
z, _ = self.in_proj_z(hidden_states)
mixed_qkvz, _ = self.in_proj_qkvz(hidden_states)
qkv_size = (self.key_dim * 2 + self.value_dim) // self.tp_size
z_size = self.value_dim // self.tp_size
mixed_qkv, z = mixed_qkvz.split([qkv_size, z_size], dim=-1)
z = z.reshape(z.size(0), -1, self.head_v_dim)
b, _ = self.in_proj_b(hidden_states)
a, _ = self.in_proj_a(hidden_states)
ba, _ = self.in_proj_ba(hidden_states)
b, a = ba.chunk(2, dim=-1)
b = b.contiguous()
a = a.contiguous()
......@@ -411,7 +273,6 @@ class Qwen3_5DecoderLayer(Qwen3NextDecoderLayer):
1,
1,
config.hidden_size,
dtype=config.dtype,
),
)
self.ffn_layer_scale = torch.nn.Parameter(
......@@ -419,7 +280,6 @@ class Qwen3_5DecoderLayer(Qwen3NextDecoderLayer):
1,
1,
config.hidden_size,
dtype=config.dtype,
),
)
......@@ -503,11 +363,18 @@ class Qwen3_5Model(Qwen3NextModel):
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
# self attention
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
# mlp
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
# GDN
("in_proj_qkvz", "in_proj_qkv", (0, 1, 2)),
("in_proj_qkvz", "in_proj_z", 3),
("in_proj_ba", "in_proj_b", 0),
("in_proj_ba", "in_proj_a", 1),
]
params_dict = dict(self.named_parameters())
......@@ -528,6 +395,12 @@ class Qwen3_5Model(Qwen3NextModel):
if name.startswith("mtp."):
continue
# Remapping the name of FP8 kv-scale.
if name.endswith("scale"):
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if "experts.gate_up_proj" in name or "experts.down_proj" in name:
is_fused_expert = True
......@@ -654,6 +527,9 @@ class Qwen3_5ForCausalLMBase(
"v_proj",
],
"gate_up_proj": ["gate_proj", "up_proj"],
# GDN fused projections.
"in_proj_qkvz": ["in_proj_qkv", "in_proj_z"],
"in_proj_ba": ["in_proj_b", "in_proj_a"],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
......@@ -751,7 +627,15 @@ class Qwen3_5MoeForCausalLM(Qwen3_5ForCausalLMBase, QwenNextMixtureOfExperts):
dummy_inputs=Qwen3VLDummyInputsBuilder,
)
class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
# Qwen3.5 does not support multimodal pruning (EVS).
supports_multimodal_pruning = False
packed_modules_mapping = Qwen3VLForConditionalGeneration.packed_modules_mapping | {
"in_proj_qkvz": ["in_proj_qkv", "in_proj_z"],
"in_proj_ba": ["in_proj_b", "in_proj_a"],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"):
# protocols have not __init__ method, so we need to use nn.Module.__init__
nn.Module.__init__(self)
config: Qwen3_5Config = vllm_config.model_config.hf_config
......@@ -868,7 +752,9 @@ class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid)
vllm_config: "VllmConfig",
) -> tuple[torch.dtype, torch.dtype]:
return MambaStateDtypeCalculator.gated_delta_net_state_dtype(
vllm_config.model_config.dtype, vllm_config.cache_config.mamba_cache_dtype
vllm_config.model_config.dtype,
vllm_config.cache_config.mamba_cache_dtype,
vllm_config.cache_config.mamba_ssm_cache_dtype,
)
@classmethod
......@@ -957,7 +843,7 @@ class Qwen3_5_MoeMixtureOfExperts(MixtureOfExperts):
class Qwen3_5MoeForConditionalGeneration(
Qwen3_5ForConditionalGeneration, Qwen3_5_MoeMixtureOfExperts
):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"):
# protocols have not __init__ method, so we need to use nn.Module.__init__
nn.Module.__init__(self)
config: Qwen3_5MoeConfig = vllm_config.model_config.hf_config
......
......@@ -7,10 +7,6 @@ from collections.abc import Callable, Iterable
import torch
from torch import nn
from transformers.models.qwen3_5.configuration_qwen3_5 import Qwen3_5TextConfig
from transformers.models.qwen3_5_moe.configuration_qwen3_5_moe import (
Qwen3_5MoeTextConfig,
)
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
......@@ -27,6 +23,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.qwen3_5 import Qwen3_5DecoderLayer, Qwen3_5RMSNorm
from vllm.model_executor.models.qwen3_next import QwenNextMixtureOfExperts
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.qwen3_5 import Qwen3_5TextConfig
from vllm.transformers_utils.configs.qwen3_5_moe import Qwen3_5MoeTextConfig
from .interfaces import (
MultiModalEmbeddings,
......
......@@ -40,6 +40,7 @@ from vllm.model_executor.layers.layernorm import (
from vllm.model_executor.layers.layernorm import RMSNormGated
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
......@@ -292,19 +293,19 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
# projection of the input hidden states
self.projection_size_qkvz = self.key_dim * 2 + self.value_dim * 2
self.projection_size_ba = self.num_v_heads * 2
self.in_proj_qkvz = ColumnParallelLinear(
input_size=self.hidden_size,
output_size=self.projection_size_qkvz,
bias=False,
# Qwen3-Next and Qwen3.5 has a different qkv_proj layout,
# we need to create qkvz_proj adaptively here.
self.in_proj_qkvz = self.create_qkvz_proj(
hidden_size=self.hidden_size,
key_dim=self.key_dim,
value_dim=self.value_dim,
quant_config=quant_config,
prefix=f"{prefix}.in_proj_qkvz",
)
# ba_proj doesn't support blockwise fp8 quantization.
self.in_proj_ba = ColumnParallelLinear(
self.in_proj_ba = MergedColumnParallelLinear(
input_size=self.hidden_size,
output_size=self.projection_size_ba,
output_sizes=[self.num_v_heads] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.in_proj_ba",
......@@ -351,7 +352,6 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
group_size=None,
norm_before_gate=True,
device=current_platform.current_device(),
dtype=config.dtype,
)
self.out_proj = RowParallelLinear(
......@@ -368,10 +368,26 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
def create_qkvz_proj(
self,
hidden_size: int,
key_dim: int,
value_dim: int,
quant_config: QuantizationConfig | None,
prefix: str,
) -> MergedColumnParallelLinear:
return MergedColumnParallelLinear(
input_size=hidden_size,
output_sizes=[sum((key_dim, key_dim, value_dim)), value_dim],
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.in_proj_qkvz",
)
def fix_query_key_value_ordering(
self,
mixed_qkvz,
mixed_ba,
mixed_qkvz: torch.Tensor,
mixed_ba: torch.Tensor,
):
"""
Derives `query`, `key` and `value` tensors from `mixed_qkvzba`.
......@@ -893,7 +909,6 @@ class Qwen3NextDecoderLayer(nn.Module):
1,
1,
config.hidden_size,
dtype=config.dtype,
),
)
self.ffn_layer_scale = torch.nn.Parameter(
......@@ -901,7 +916,6 @@ class Qwen3NextDecoderLayer(nn.Module):
1,
1,
config.hidden_size,
dtype=config.dtype,
),
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment