"tests/utils_/test_tensor_schema.py" did not exist on "88e46c7c8dfa7c761259795773b758114721169a"
Unverified Commit 9562912c authored by JJJYmmm's avatar JJJYmmm Committed by GitHub
Browse files

[MODEL] Adding Support for Qwen3.5 Models (#34110)


Signed-off-by: default avatarJJJYmmm <1650675829@qq.com>
Signed-off-by: default avatarJJJYmmm <92386084+JJJYmmm@users.noreply.github.com>
Signed-off-by: default avatarRoger Wang <hey@rogerw.io>
Co-authored-by: default avatarwulipc <wulipc@users.noreply.github.com>
Co-authored-by: default avatarywang96 <ywang96@users.noreply.github.com>
Co-authored-by: default avatarIsotr0py <Isotr0py@users.noreply.github.com>
Co-authored-by: default avatarIsotr0py <2037008807@qq.com>
Co-authored-by: default avatarRoger Wang <hey@rogerw.io>
parent 9bdb06b4
......@@ -738,6 +738,8 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `Qwen2VLForConditionalGeneration` | QVQ, Qwen2-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/QVQ-72B-Preview`, `Qwen/Qwen2-VL-7B-Instruct`, `Qwen/Qwen2-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ |
| `Qwen2_5_VLForConditionalGeneration` | Qwen2.5-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/Qwen2.5-VL-3B-Instruct`, `Qwen/Qwen2.5-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ |
| `Qwen2_5OmniThinkerForConditionalGeneration` | Qwen2.5-Omni | T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>+</sup> | `Qwen/Qwen2.5-Omni-3B`, `Qwen/Qwen2.5-Omni-7B` | ✅︎ | ✅︎ |
| `Qwen3_5ForConditionalGeneration` | Qwen3.5 | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/Qwen3.5-9B-Instruct`, etc. | ✅︎ | ✅︎ |
| `Qwen3_5MoeForConditionalGeneration` | Qwen3.5-MOE | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/Qwen3.5-35B-A3B-Instruct`, etc. | ✅︎ | ✅︎ |
| `Qwen3VLForConditionalGeneration` | Qwen3-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/Qwen3-VL-4B-Instruct`, etc. | ✅︎ | ✅︎ |
| `Qwen3VLMoeForConditionalGeneration` | Qwen3-VL-MOE | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/Qwen3-VL-30B-A3B-Instruct`, etc. | ✅︎ | ✅︎ |
| `Qwen3OmniMoeThinkerForConditionalGeneration` | Qwen3-Omni | T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>+</sup> | `Qwen/Qwen3-Omni-30B-A3B-Instruct`, `Qwen/Qwen3-Omni-30B-A3B-Thinking` | ✅︎ | ✅︎ |
......
......@@ -970,6 +970,26 @@ _MULTIMODAL_EXAMPLE_MODELS = {
max_model_len=4096,
min_transformers_version="4.57",
),
"Qwen3_5ForConditionalGeneration": _HfExamplesInfo(
"Qwen/Qwen3.5-9B-Instruct",
max_model_len=4096,
min_transformers_version="5.1.0",
),
"Qwen3_5MoeForConditionalGeneration": _HfExamplesInfo(
"Qwen/Qwen3.5-35B-A3B-Instruct",
max_model_len=4096,
min_transformers_version="5.1.0",
),
"Qwen3_5MTP": _HfExamplesInfo(
"Qwen/Qwen3.5-9B-Instruct",
speculative_model="Qwen/Qwen3.5-9B-Instruct",
min_transformers_version="5.1.0",
),
"Qwen3_5MoeMTP": _HfExamplesInfo(
"Qwen/Qwen3.5-35B-A3B-Instruct",
speculative_model="Qwen/Qwen3.5-35B-A3B-Instruct",
min_transformers_version="5.1.0",
),
"Qwen3OmniMoeForConditionalGeneration": _HfExamplesInfo(
"Qwen/Qwen3-Omni-30B-A3B-Instruct",
max_model_len=4096,
......
......@@ -1218,8 +1218,8 @@ class ModelConfig:
if attn_type_list:
return sum(t == 1 for t in attn_type_list[start:end])
# Hybrid model Qwen3Next
layer_types_value = getattr(self.hf_config, "layer_types", None)
# Hybrid model Qwen3Next Qwen3.5 Series
layer_types_value = getattr(self.hf_text_config, "layer_types", None)
if layer_types_value is not None:
if block_type == "attention":
return sum(
......
......@@ -37,6 +37,7 @@ MTPModelTypes = Literal[
"ernie_mtp",
"exaone_moe_mtp",
"qwen3_next_mtp",
"qwen3_5_mtp",
"longcat_flash_mtp",
"mtp",
"pangu_ultra_moe_mtp",
......@@ -263,6 +264,16 @@ class SpeculativeConfig:
{"n_predict": n_predict, "architectures": ["ExaoneMoeMTP"]}
)
if hf_config.model_type in ("qwen3_5", "qwen3_5_moe"):
is_moe = hf_config.model_type == "qwen3_5_moe"
hf_config.model_type = "qwen3_5_mtp"
n_predict = getattr(hf_config, "mtp_num_hidden_layers", None)
hf_config.update(
{
"n_predict": n_predict,
"architectures": ["Qwen3_5MoeMTP" if is_moe else "Qwen3_5MTP"],
}
)
if hf_config.model_type == "longcat_flash":
hf_config.model_type = "longcat_flash_mtp"
n_predict = getattr(hf_config, "num_nextn_predict_layers", 1)
......
......@@ -43,7 +43,8 @@ class MambaBase(AttentionLayerBase):
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
if (
vllm_config.speculative_config is not None
and vllm_config.model_config.hf_config.model_type not in ["qwen3_next"]
and vllm_config.model_config.hf_config.model_type
not in ["qwen3_next", "qwen3_5", "qwen3_5_moe"]
):
raise NotImplementedError(
"Mamba with speculative decoding is not supported yet."
......
This diff is collapsed.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Inference-only Qwen3_5 MTP model."""
import typing
from collections.abc import Callable, Iterable
import torch
from torch import nn
from transformers.models.qwen3_5.configuration_qwen3_5 import Qwen3_5TextConfig
from transformers.models.qwen3_5_moe.configuration_qwen3_5_moe import (
Qwen3_5MoeTextConfig,
)
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.qwen3_5 import Qwen3_5DecoderLayer, Qwen3_5RMSNorm
from vllm.model_executor.models.qwen3_next import QwenNextMixtureOfExperts
from vllm.sequence import IntermediateTensors
from .interfaces import (
MultiModalEmbeddings,
SupportsMultiModal,
_require_is_multimodal,
)
from .utils import (
AutoWeightsLoader,
PPMissingLayer,
_merge_multimodal_embeddings,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory,
maybe_prefix,
)
logger = init_logger(__name__)
@support_torch_compile(
dynamic_arg_dims={
"input_ids": 0,
# positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
# otherwise (seq_len, ).
"positions": -1,
"intermediate_tensors": 0,
"inputs_embeds": 0,
"hidden_states": 0,
}
)
class Qwen3_5MultiTokenPredictor(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
model_config = vllm_config.model_config
quant_config = vllm_config.quant_config
config: Qwen3_5TextConfig | Qwen3_5MoeTextConfig = model_config.hf_text_config
self.config = config
self.vocab_size = config.vocab_size
self.mtp_start_layer_idx = config.num_hidden_layers
self.num_mtp_layers = getattr(config, "mtp_num_hidden_layers", 1)
self.embed_tokens = VocabParallelEmbedding(
self.vocab_size,
config.hidden_size,
)
self.fc = ColumnParallelLinear(
self.config.hidden_size * 2,
self.config.hidden_size,
gather_output=True,
bias=False,
return_bias=False,
quant_config=quant_config,
prefix=f"{prefix}.fc",
)
self.layers = torch.nn.ModuleList(
Qwen3_5DecoderLayer(
vllm_config,
layer_type="full_attention",
prefix=f"{prefix}.layers.{idx}",
)
for idx in range(self.num_mtp_layers)
)
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size
)
self.norm = Qwen3_5RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.pre_fc_norm_hidden = Qwen3_5RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
self.pre_fc_norm_embedding = Qwen3_5RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
spec_step_idx: int = 0,
) -> torch.Tensor:
if get_pp_group().is_first_rank:
if inputs_embeds is None:
inputs_embeds = self.embed_input_ids(input_ids)
assert hidden_states.shape[-1] == inputs_embeds.shape[-1]
inputs_embeds = self.pre_fc_norm_embedding(inputs_embeds)
hidden_states = self.pre_fc_norm_hidden(hidden_states)
hidden_states = torch.cat([inputs_embeds, hidden_states], dim=-1)
hidden_states = self.fc(hidden_states)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
current_step_idx = spec_step_idx % self.num_mtp_layers
hidden_states, residual = self.layers[current_step_idx](
positions=positions,
hidden_states=hidden_states,
residual=residual,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors(
{"hidden_states": hidden_states, "residual": residual}
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
def load_fused_expert_weights(
self,
name: str,
params_dict: dict,
loaded_weight: torch.Tensor,
shard_id: str,
num_experts: int,
) -> bool:
param = params_dict[name]
weight_loader = typing.cast(Callable[..., bool], param.weight_loader)
loaded_local_expert = False
for expert_id in range(num_experts):
curr_expert_weight = loaded_weight[expert_id]
success = weight_loader(
param,
curr_expert_weight,
name,
shard_id,
expert_id,
return_success=True,
)
if success:
loaded_local_expert = True
return loaded_local_expert
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
self,
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.num_experts
if hasattr(self.config, "num_experts")
else 0,
)
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
is_fused_expert = False
fused_expert_params_mapping = [
("experts.w13_weight", "experts.gate_up_proj", 0, "w1"),
("experts.w2_weight", "experts.down_proj", 0, "w2"),
]
num_experts = (
self.config.num_experts if hasattr(self.config, "num_experts") else 0
)
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if "experts.gate_up_proj" in name or "experts.down_proj" in name:
is_fused_expert = True
expert_params_mapping = fused_expert_params_mapping
if weight_name not in name:
continue
if "mlp.experts" in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
is_expert_weight = False
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
is_expert_weight = True
name_mapped = name.replace(weight_name, param_name)
# Skip layers on other devices.
if is_pp_missing_parameter(name_mapped, self):
continue
if is_fused_expert:
# qwen3.5 no need to transpose
# loaded_weight = loaded_weight.transpose(-1, -2)
if "experts.gate_up_proj" in name:
loaded_weight = loaded_weight.chunk(2, dim=-2)
success_w1 = self.load_fused_expert_weights(
name_mapped,
params_dict,
loaded_weight[0],
"w1",
num_experts,
)
success_w3 = self.load_fused_expert_weights(
name_mapped,
params_dict,
loaded_weight[1],
"w3",
num_experts,
)
success = success_w1 and success_w3
else:
# down_proj
success = self.load_fused_expert_weights(
name_mapped,
params_dict,
loaded_weight,
shard_id,
num_experts,
)
if success:
name = name_mapped
break
else:
# Skip loading extra bias for GPTQ models.
if (
name_mapped.endswith(".bias")
or name_mapped.endswith("_bias")
) and name_mapped not in params_dict:
continue
param = params_dict[name_mapped]
weight_loader = param.weight_loader
success = weight_loader(
param,
loaded_weight,
name_mapped,
shard_id=shard_id,
expert_id=expert_id,
return_success=True,
)
if success:
name = name_mapped
break
else:
if is_expert_weight:
# We've checked that this is an expert weight
# However it's not mapped locally to this rank
# So we simply skip it
continue
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
if name not in params_dict:
logger.warning_once(
f"Parameter {name} not found in params_dict, skip loading"
)
continue
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
@support_torch_compile(
dynamic_arg_dims={
"input_ids": 0,
# positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
# otherwise (seq_len, ).
"positions": -1,
"intermediate_tensors": 0,
"inputs_embeds": 0,
"hidden_states": 0,
}
)
class Qwen3_5MTP(nn.Module, SupportsMultiModal):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": ["up_proj", "down_proj"],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_text_config
self.vllm_config = vllm_config
cache_config = vllm_config.cache_config
if cache_config.mamba_cache_mode == "all":
raise NotImplementedError(
"Qwen3_5MTP currently does not support 'all' prefix caching, "
"please use '--mamba-cache-mode=align' instead"
)
self.quant_config = vllm_config.quant_config
super().__init__()
self.config = config
self.model = Qwen3_5MultiTokenPredictor(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "mtp")
)
if get_pp_group().is_last_rank:
if config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
prefix=maybe_prefix(prefix, "lm_head"),
)
else:
self.lm_head = PPMissingLayer()
self.logits_processor = LogitsProcessor(config.vocab_size)
def embed_input_ids(
self,
input_ids: torch.Tensor,
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
inputs_embeds = self._embed_text_input_ids(
input_ids,
self.model.embed_input_ids,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
return inputs_embeds
is_multimodal = _require_is_multimodal(is_multimodal)
inputs_embeds = _merge_multimodal_embeddings(
inputs_embeds=inputs_embeds,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
**kwargs: object,
):
hidden_states = self.model(
input_ids, positions, hidden_states, intermediate_tensors, inputs_embeds
)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
spec_step_idx: int = 0,
) -> torch.Tensor | None:
return self.logits_processor(self.lm_head, hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
def remap_weight_names(weights):
for name, weight in weights:
if name.startswith("mtp."):
name = name.replace("mtp.", "model.")
elif any(key in name for key in ["embed_tokens", "lm_head"]):
if "embed_tokens" in name:
name = name.replace("language_model.", "")
else:
continue
yield name, weight
loader = AutoWeightsLoader(self)
return loader.load_weights(remap_weight_names(weights))
class Qwen3_5MoeMTP(Qwen3_5MTP, QwenNextMixtureOfExperts):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
self.set_moe_parameters()
......@@ -105,7 +105,7 @@ class Qwen3NextSparseMoeBlock(nn.Module):
def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
config = vllm_config.model_config.hf_text_config
parallel_config = vllm_config.parallel_config
quant_config = vllm_config.quant_config
......@@ -176,7 +176,7 @@ class Qwen3NextSparseMoeBlock(nn.Module):
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
renormalize=getattr(config, "norm_topk_prob", True),
quant_config=quant_config,
prefix=f"{prefix}.experts",
enable_eplb=self.enable_eplb,
......@@ -965,7 +965,7 @@ class Qwen3NextModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config: Qwen3NextConfig = vllm_config.model_config.hf_config
config: Qwen3NextConfig = vllm_config.model_config.hf_text_config
parallel_config = vllm_config.parallel_config
eplb_config = parallel_config.eplb_config
......@@ -1042,7 +1042,7 @@ class Qwen3NextModel(nn.Module):
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.num_experts,
num_experts=getattr(self.config, "num_experts", 0),
num_redundant_experts=self.num_redundant_experts,
)
......@@ -1201,7 +1201,7 @@ class Qwen3NextForCausalLM(
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
config = vllm_config.model_config.hf_text_config
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
......@@ -1265,7 +1265,7 @@ class Qwen3NextForCausalLM(
cls, vllm_config: "VllmConfig"
) -> tuple[tuple[int, int], tuple[int, int]]:
parallel_config = vllm_config.parallel_config
hf_config = vllm_config.model_config.hf_config
hf_config = vllm_config.model_config.hf_text_config
tp_size = parallel_config.tensor_parallel_size
num_spec = (
vllm_config.speculative_config.num_speculative_tokens
......
......@@ -466,6 +466,14 @@ _MULTIMODAL_MODELS = {
"qwen3_vl_moe",
"Qwen3VLMoeForConditionalGeneration",
),
"Qwen3_5ForConditionalGeneration": (
"qwen3_5",
"Qwen3_5ForConditionalGeneration",
),
"Qwen3_5MoeForConditionalGeneration": (
"qwen3_5",
"Qwen3_5MoeForConditionalGeneration",
),
"SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
"Step3VLForConditionalGeneration": ("step3_vl", "Step3VLForConditionalGeneration"), # noqa: E501
"TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"), # noqa: E501
......@@ -509,6 +517,8 @@ _SPECULATIVE_DECODING_MODELS = {
"OpenPanguMTPModel": ("openpangu_mtp", "OpenPanguMTP"),
"Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"),
"Step3p5MTP": ("step3p5_mtp", "Step3p5MTP"),
"Qwen3_5MTP": ("qwen3_5_mtp", "Qwen3_5MTP"),
"Qwen3_5MoeMTP": ("qwen3_5_mtp", "Qwen3_5MoeMTP"),
# Temporarily disabled.
# # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
# "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
......
......@@ -420,6 +420,11 @@ class Qwen3NextMTPModelArchConfigConvertor(ModelArchConfigConvertorBase):
return getattr(self.hf_text_config, "num_nextn_predict_layers", 0)
class Qwen3_5MTPModelArchConfigConvertor(ModelArchConfigConvertorBase):
def get_num_hidden_layers(self) -> int:
return getattr(self.hf_text_config, "mtp_num_hidden_layers", 0)
class PanguUltraMoeMTPModelArchConfigConvertor(ModelArchConfigConvertorBase):
def get_num_hidden_layers(self) -> int:
return getattr(self.hf_text_config, "num_nextn_predict_layers", 0)
......@@ -445,6 +450,7 @@ MODEL_ARCH_CONFIG_CONVERTORS = {
"nemotron-nas": NemotronNasModelArchConfigConvertor,
"deepseek_mtp": DeepSeekMTPModelArchConfigConvertor,
"qwen3_next_mtp": Qwen3NextMTPModelArchConfigConvertor,
"qwen3_5_mtp": Qwen3_5MTPModelArchConfigConvertor,
"mimo_mtp": MimoMTPModelArchConfigConvertor,
"glm4_moe_mtp": GLM4MoeMTPModelArchConfigConvertor,
"glm_ocr_mtp": GLM4MoeMTPModelArchConfigConvertor,
......
......@@ -1356,6 +1356,8 @@ class SpecDecodeBaseProposer:
"Qwen3VLMoeForConditionalGeneration",
"HunYuanVLForConditionalGeneration",
"GlmOcrForConditionalGeneration",
"Qwen3_5ForConditionalGeneration",
"Qwen3_5MoeForConditionalGeneration",
]:
self.model.config.image_token_index = target_model.config.image_token_id
elif self.get_model_name(target_model) == "PixtralForConditionalGeneration":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment