Unverified commit e7a1387e, authored by Kyungmin Lee, committed by GitHub
Browse files

Add EXAONE-4.5 (#39388)


Signed-off-by: lkm2835 <lkm2835@gmail.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
parent f83de719
......@@ -550,6 +550,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `DeepseekOCR2ForCausalLM` | DeepSeek-OCR-2 | T + I<sup>+</sup> | `deepseek-ai/DeepSeek-OCR-2`, etc. | ✅︎ | ✅︎ |
| `Eagle2_5_VLForConditionalGeneration` | Eagle2.5-VL | T + I<sup>E+</sup> | `nvidia/Eagle2.5-8B`, etc. | ✅︎ | ✅︎ |
| `Ernie4_5_VLMoeForConditionalGeneration` | Ernie4.5-VL | T + I<sup>+</sup>/ V<sup>+</sup> | `baidu/ERNIE-4.5-VL-28B-A3B-PT`, `baidu/ERNIE-4.5-VL-424B-A47B-PT` | | ✅︎ |
| `Exaone4_5_ForConditionalGeneration` | EXAONE-4.5 | T + I<sup>E+</sup> | `LGAI-EXAONE/EXAONE-4.5-33B`, etc. | ✅︎ | ✅︎ |
| `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b`, etc. | | ✅︎ |
| `Gemma3ForConditionalGeneration` | Gemma 3 | T + I<sup>E+</sup> | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ |
| `Gemma3nForConditionalGeneration` | Gemma 3n | T + I + A | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | |
......
......@@ -421,6 +421,43 @@ def run_ernie45_vl(questions: list[str], modality: str) -> ModelRequestData:
)
# EXAONE-4.5
def run_exaone4_5(questions: list[str], modality: str) -> ModelRequestData:
    """Build the engine arguments and prompts for the EXAONE-4.5 example.

    Args:
        questions: Questions to ask; one prompt is produced per question.
        modality: Either ``"image"`` or ``"video"``. Selects the pad token
            that is placed between the ``<vision>`` tags of each prompt.

    Returns:
        A ``ModelRequestData`` carrying the engine arguments and the prompts.

    Raises:
        ValueError: If ``modality`` is neither ``"image"`` nor ``"video"``.
    """
    # Resolve the placeholder before anything else so that an unsupported
    # modality fails with a clear message instead of an UnboundLocalError
    # raised later while the prompts are being formatted.
    pad_tokens = {"image": "<|image_pad|>", "video": "<|video_pad|>"}
    if modality not in pad_tokens:
        raise ValueError(
            f"Unsupported modality for EXAONE-4.5: {modality!r} "
            "(expected 'image' or 'video')"
        )
    placeholder = pad_tokens[modality]

    model_name = "LGAI-EXAONE/EXAONE-4.5-33B"

    engine_args = EngineArgs(
        model=model_name,
        max_model_len=4096,
        max_num_seqs=5,
        mm_processor_kwargs={
            "min_pixels": 28 * 28,
            "max_pixels": 1280 * 28 * 28,
            "fps": 1,
        },
        limit_mm_per_prompt={modality: 1},
    )

    # Chat layout: system turn, user turn (vision placeholder + question),
    # then an open assistant turn for the model to complete.
    prompts = [
        (
            "<|system|>\nYou are a helpful assistant.<|endofturn|>\n"
            f"<|user|>\n<vision>{placeholder}</vision>"
            f"{question}<|endofturn|>\n"
            "<|assistant|>\n"
        )
        for question in questions
    ]

    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
    )
# Fuyu
def run_fuyu(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
......@@ -2199,6 +2236,7 @@ model_example_map = {
"dots_ocr": run_dots_ocr,
"eagle2_5": run_eagle2_5,
"ernie45_vl": run_ernie45_vl,
"exaone4_5": run_exaone4_5,
"fuyu": run_fuyu,
"gemma3": run_gemma3,
"gemma3n": run_gemma3n,
......
......@@ -241,6 +241,41 @@ def load_deepseek_ocr(question: str, image_urls: list[str]) -> ModelRequestData:
)
# exaone4_5
def load_exaone4_5(question: str, image_urls: list[str]) -> ModelRequestData:
    """Prepare a multi-image EXAONE-4.5 request.

    Args:
        question: Text question appended after the image entries.
        image_urls: URLs of the images; each one becomes an image entry in
            the user message and is fetched for ``image_data``.

    Returns:
        A ``ModelRequestData`` with the engine arguments, the prompt rendered
        by the processor's chat template, and the fetched images.
    """
    model_name = "LGAI-EXAONE/EXAONE-4.5-33B"

    # Allow exactly as many images per prompt as were supplied.
    engine_args = EngineArgs(
        model=model_name,
        max_model_len=8192,
        max_num_seqs=2,
        limit_mm_per_prompt={"image": len(image_urls)},
    )

    # Single user turn: all images first, then the question text.
    user_content = [{"type": "image", "image": url} for url in image_urls]
    user_content.append({"type": "text", "text": question})
    messages = [{"role": "user", "content": user_content}]

    chat_processor = AutoProcessor.from_pretrained(model_name)
    prompt = chat_processor.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    images = [fetch_image(url) for url in image_urls]
    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
        image_data=images,
    )
def load_gemma3(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "google/gemma-3-4b-it"
......@@ -1450,6 +1485,7 @@ model_example_map = {
"command_a_vision": load_command_a_vision,
"deepseek_vl_v2": load_deepseek_vl2,
"deepseek_ocr": load_deepseek_ocr,
"exaone4_5": load_exaone4_5,
"gemma3": load_gemma3,
"h2ovl_chat": load_h2ovl,
"hunyuan_vl": load_hunyuan_vl,
......
......@@ -813,6 +813,10 @@ _MULTIMODAL_EXAMPLE_MODELS = {
trust_remote_code=True,
revision="refs/pr/17",
),
"Exaone4_5_ForConditionalGeneration": _HfExamplesInfo(
"LGAI-EXAONE/EXAONE-4.5-33B",
min_transformers_version="5.6.0",
),
"FireRedASR2ForConditionalGeneration": _HfExamplesInfo(
"allendou/FireRedASR2-LLM-vllm",
),
......@@ -1306,6 +1310,11 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
min_transformers_version="5.1.0",
enable_prefix_caching=False,
),
"Exaone4_5_MTP": _HfExamplesInfo(
"LGAI-EXAONE/EXAONE-4.5-33B",
speculative_model="LGAI-EXAONE/EXAONE-4.5-33B",
min_transformers_version="5.6.0",
),
"ExtractHiddenStatesModel": _HfExamplesInfo(
"Qwen/Qwen3-8B",
speculative_method="extract_hidden_states",
......
......@@ -40,6 +40,7 @@ MTPModelTypes = Literal[
"ernie_mtp",
"nemotron_h_mtp",
"exaone_moe_mtp",
"exaone4_5_mtp",
"qwen3_next_mtp",
"qwen3_5_mtp",
"longcat_flash_mtp",
......@@ -327,7 +328,13 @@ class SpeculativeConfig:
hf_config.update(
{"n_predict": n_predict, "architectures": ["ExaoneMoeMTP"]}
)
if "exaone4_5" in hf_config.model_type:
hf_config.model_type = "exaone4_5_mtp"
if hf_config.model_type == "exaone4_5_mtp":
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
hf_config.update(
{"n_predict": n_predict, "architectures": ["Exaone4_5_MTP"]}
)
if hf_config.model_type in ("qwen3_5", "qwen3_5_moe"):
is_moe = hf_config.model_type == "qwen3_5_moe"
hf_config.model_type = "qwen3_5_mtp"
......
......@@ -75,6 +75,7 @@ class Exaone4GatedMLP(nn.Module):
reduce_results: bool = True,
bias: bool = False,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
......@@ -83,6 +84,7 @@ class Exaone4GatedMLP(nn.Module):
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
disable_tp=use_data_parallel,
)
self.down_proj = RowParallelLinear(
input_size=intermediate_size,
......@@ -91,6 +93,7 @@ class Exaone4GatedMLP(nn.Module):
quant_config=quant_config,
reduce_results=reduce_results,
prefix=f"{prefix}.down_proj",
disable_tp=use_data_parallel,
)
if hidden_act != "silu":
raise ValueError(
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only EXAONE-4.5 model compatible with HuggingFace weights."""
from collections.abc import Callable, Iterable
from functools import partial
import einops
import torch
import torch.nn as nn
from transformers.models.exaone4_5 import (
Exaone4_5_Config,
Exaone4_5_ImageProcessor,
Exaone4_5_Processor,
)
from transformers.models.exaone4_5.configuration_exaone4_5 import Exaone4_5_VisionConfig
from vllm.compilation.decorators import (
should_torch_compile_mm_encoder,
support_torch_compile,
)
from vllm.config import VllmConfig
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
from vllm.logger import init_logger
from vllm.model_executor.layers.attention.mm_encoder_attention import MMEncoderAttention
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding.common import (
ApplyRotaryEmb,
)
from vllm.model_executor.models.exaone4 import Exaone4GatedMLP as Exaone4_5_VisionMLP
from vllm.model_executor.models.qwen2_5_vl import (
Qwen2_5_VisionTransformer,
Qwen2_5_VLForConditionalGeneration,
Qwen2VLProcessingInfo,
)
from vllm.multimodal import MULTIMODAL_REGISTRY
from .qwen2_vl import Qwen2VLDummyInputsBuilder as Exaone4_5_DummyInputsBuilder
from .qwen2_vl import Qwen2VLMultiModalProcessor as Exaone4_5_MultiModalProcessor
from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
# Module-level logger, created from this module's ``__name__``.
logger = init_logger(__name__)
# === Vision Encoder === #
class EXAONE4_5_VisionAttention(nn.Module):
    """Self-attention for the EXAONE-4.5 vision encoder.

    A fused QKV projection with separate query and key/value head counts is
    followed by rotary embedding on the queries and keys, the multimodal
    encoder attention op, and an output projection.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        num_kv_heads: int,
        projection_size: int,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ) -> None:
        """Create the projections and the attention op.

        Args:
            embed_dim: Input/output feature size of the layer.
            num_heads: Total number of query heads.
            num_kv_heads: Total number of key/value heads.
            projection_size: Width of the attention output fed to ``proj``.
            quant_config: Optional quantization config for the linear layers.
            prefix: Module-name prefix used for the sub-layers.
            use_data_parallel: When true, tensor parallelism is disabled for
                the linear layers and the partition size is treated as 1.
        """
        super().__init__()

        # Tensor-parallel layout; a data-parallel encoder is not sharded.
        if use_data_parallel:
            self.tp_size = 1
        else:
            self.tp_size = parallel_state.get_tensor_model_parallel_world_size()
        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()

        # Per-head and per-partition sizes.
        self.hidden_size_per_attention_head = dist_utils.divide(
            projection_size, num_heads
        )
        self.num_attention_heads_per_partition = dist_utils.divide(
            num_heads, self.tp_size
        )

        self.total_num_heads = num_heads
        self.total_num_kv_heads = num_kv_heads
        self.num_heads = num_heads // self.tp_size
        # Keep at least one KV head on every partition.
        self.num_kv_heads = max(1, num_kv_heads // self.tp_size)
        self.head_dim = embed_dim // num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        self.qkv = QKVParallelLinear(
            hidden_size=embed_dim,
            head_size=self.hidden_size_per_attention_head,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv",
            disable_tp=use_data_parallel,
        )
        self.proj = RowParallelLinear(
            input_size=projection_size,
            output_size=embed_dim,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.proj",
            disable_tp=use_data_parallel,
        )
        self.attn = MMEncoderAttention(
            num_heads=self.num_attention_heads_per_partition,
            head_size=self.hidden_size_per_attention_head,
            num_kv_heads=self.num_kv_heads,
            scale=self.hidden_size_per_attention_head**-0.5,
            prefix=f"{prefix}.attn",
        )
        self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True)

    def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Split a fused ``[s, b, (h + 2*hk) * d]`` tensor into q, k and v.

        Returns:
            ``(q, k, v)``, each contiguous and laid out as ``[b, s, heads, d]``
            with ``h`` query heads and ``hk`` heads for key and value.
        """
        seq, batch, _ = qkv.shape
        n_q = self.num_heads
        n_kv = self.num_kv_heads

        # Expose the head axis, then cut it into query / key / value groups.
        per_head = qkv.view(seq, batch, n_q + 2 * n_kv, self.head_dim)
        query, key, value = per_head.split([n_q, n_kv, n_kv], dim=2)

        # [s, b, h, d] -> [b, s, h, d]
        query = query.permute(1, 0, 2, 3).contiguous()
        key = key.permute(1, 0, 2, 3).contiguous()
        value = value.permute(1, 0, 2, 3).contiguous()
        return (query, key, value)

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb_cos: torch.Tensor,
        rotary_pos_emb_sin: torch.Tensor,
        max_seqlen: int | None = None,
    ) -> torch.Tensor:
        """Run attention over ``x`` (``[s, b, c]``) and return the projection.

        ``cu_seqlens`` and ``max_seqlen`` are forwarded unchanged to the
        attention op; the cos/sin tables are applied to queries and keys only.
        """
        # [s, b, c] -> [s, b, (h + 2*hk) * d]
        fused, _ = self.qkv(x)
        _, batch_size, _ = fused.shape

        q, k, v = self.split_qkv(fused)
        q = self.apply_rotary_emb(q, rotary_pos_emb_cos, rotary_pos_emb_sin)
        k = self.apply_rotary_emb(k, rotary_pos_emb_cos, rotary_pos_emb_sin)

        attended = self.attn(
            query=q,
            key=k,
            value=v,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )

        # Back to sequence-first layout with the heads merged.
        merged = einops.rearrange(
            attended, "b s h d -> s b (h d)", b=batch_size
        ).contiguous()

        projected, _ = self.proj(merged)
        return projected
@support_torch_compile(
    dynamic_arg_dims={
        "x": 0,
        "cu_seqlens": 0,
        "rotary_pos_emb_cos": 0,
        "rotary_pos_emb_sin": 0,
    },
    enable_if=should_torch_compile_mm_encoder,
    is_encoder=True,
)
class Exaone4_5_VisionBlock(nn.Module):
    """One vision-encoder layer: normed attention followed by a gated MLP.

    ``norm2`` is called with a ``residual`` keyword and must return a pair,
    so the norm layer has to support that fused add-and-normalize call.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        num_kv_heads: int,
        mlp_hidden_dim: int,
        hidden_act: str = "silu",
        norm_layer: Callable[[int], nn.Module] | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ) -> None:
        """Build the two norms, the attention layer and the MLP.

        Args:
            dim: Feature size of the block.
            num_heads: Number of query heads for attention.
            num_kv_heads: Number of key/value heads for attention.
            mlp_hidden_dim: Intermediate size of the MLP.
            hidden_act: Activation name passed to the MLP.
            norm_layer: Factory taking ``dim`` and returning the norm module.
                It must accept ``norm(x, residual=...)`` and return two
                tensors. Defaults to ``RMSNorm`` with ``eps=1e-6``.
            quant_config: Optional quantization config for linear layers.
            prefix: Module-name prefix used for the sub-layers.
            use_data_parallel: Forwarded to attention and MLP.
        """
        super().__init__()
        if norm_layer is None:
            # forward() calls norm2(x, residual=...) and unpacks two results.
            # nn.LayerNorm.forward takes only the input tensor, so the former
            # LayerNorm default failed on the first forward pass.
            norm_layer = partial(RMSNorm, eps=1e-6)
        self.norm1 = norm_layer(dim)
        self.norm2 = norm_layer(dim)
        self.attn = EXAONE4_5_VisionAttention(
            embed_dim=dim,
            num_heads=num_heads,
            num_kv_heads=num_kv_heads,
            projection_size=dim,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            use_data_parallel=use_data_parallel,
        )
        self.mlp = Exaone4_5_VisionMLP(
            dim,
            mlp_hidden_dim,
            hidden_act=hidden_act,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
            use_data_parallel=use_data_parallel,
        )

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb_cos: torch.Tensor,
        rotary_pos_emb_sin: torch.Tensor,
        max_seqlen: int | None = None,  # Only used for Flash Attention
        seqlens: list[int] | None = None,  # Only used for xFormers
    ) -> torch.Tensor:
        """Apply attention and the MLP to ``x`` and return the result.

        ``seqlens`` is accepted for interface compatibility and is not read
        by this block.
        """
        x_attn = self.attn(
            self.norm1(x),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb_cos=rotary_pos_emb_cos,
            rotary_pos_emb_sin=rotary_pos_emb_sin,
            max_seqlen=max_seqlen,
        )
        # The norm is handed the block input and the attention output
        # together and returns (normalized tensor, residual); presumably the
        # residual is their sum -- confirm against the norm implementation.
        x_fused_norm, residual = self.norm2(x, residual=x_attn)
        x = residual + self.mlp(x_fused_norm)
        return x
class EXAONE4_5_VisionTransformer(Qwen2_5_VisionTransformer):
    """Qwen2.5 vision transformer with its blocks swapped for EXAONE-4.5 ones."""

    def __init__(
        self,
        vision_config: Exaone4_5_VisionConfig,
        norm_eps: float = 1e-6,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ) -> None:
        """Initialise the parent encoder, then replace ``self.blocks``.

        Args:
            vision_config: Vision configuration; ``depth``,
                ``num_key_value_heads``, ``intermediate_size`` and
                ``hidden_act`` are read here.
            norm_eps: Epsilon for the RMSNorm layers inside each block.
            quant_config: Optional quantization config for the blocks.
            prefix: Module-name prefix; blocks are named
                ``{prefix}.blocks.{index}``.
            use_data_parallel: Forwarded to every block.
        """
        super().__init__(
            vision_config=vision_config,
            norm_eps=norm_eps,
            quant_config=quant_config,
            prefix=prefix,
        )

        self.num_kv_heads = vision_config.num_key_value_heads

        # One norm factory shared by all blocks.
        rms_norm = partial(RMSNorm, eps=norm_eps)

        encoder_blocks = []
        for layer_idx in range(vision_config.depth):
            encoder_blocks.append(
                Exaone4_5_VisionBlock(
                    dim=self.hidden_size,
                    num_heads=self.num_heads,
                    num_kv_heads=self.num_kv_heads,
                    mlp_hidden_dim=vision_config.intermediate_size,
                    hidden_act=vision_config.hidden_act,
                    norm_layer=rms_norm,
                    quant_config=quant_config,
                    prefix=f"{prefix}.blocks.{layer_idx}",
                    use_data_parallel=use_data_parallel,
                )
            )
        # Overrides the blocks created by the parent constructor.
        self.blocks = nn.ModuleList(encoder_blocks)
class Exaone4_5_ProcessingInfo(Qwen2VLProcessingInfo):
    """Processing info that resolves the EXAONE-4.5 config and processors."""

    def get_hf_config(self):
        """Return the HF config, requested as an ``Exaone4_5_Config``."""
        return self.ctx.get_hf_config(Exaone4_5_Config)

    def get_hf_processor(self, **kwargs: object) -> Exaone4_5_Processor:
        """Return the HF processor; ``use_fast`` defaults to ``True``."""
        # Pull ``use_fast`` out first so it is not passed twice.
        use_fast = kwargs.pop("use_fast", True)
        return self.ctx.get_hf_processor(
            Exaone4_5_Processor,
            use_fast=use_fast,
            **kwargs,
        )

    def get_image_processor(self, **kwargs: object) -> Exaone4_5_ImageProcessor:
        """Return a new image processor constructed from ``kwargs`` alone."""
        image_processor = Exaone4_5_ImageProcessor(**kwargs)
        return image_processor
@MULTIMODAL_REGISTRY.register_processor(
    Exaone4_5_MultiModalProcessor,
    info=Exaone4_5_ProcessingInfo,
    dummy_inputs=Exaone4_5_DummyInputsBuilder,
)
class Exaone4_5_ForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
    """EXAONE-4.5 vision-language model: vision tower plus Exaone4 language model."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the vision tower and the language model.

        The parent ``__init__`` is bypassed on purpose; only
        ``nn.Module.__init__`` runs, and the sub-modules are created here.
        """
        nn.Module.__init__(self)

        model_config = vllm_config.model_config
        config: Exaone4_5_Config = model_config.hf_config
        self.vllm_config = vllm_config
        multimodal_config = model_config.multimodal_config

        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
        self.config = config
        self.multimodal_config = multimodal_config
        self.is_multimodal_pruning_enabled = (
            multimodal_config.is_multimodal_pruning_enabled()
        )

        # Vision tower, serving both image and video inputs.
        with self._mark_tower_model(vllm_config, {"image", "video"}):
            vision_eps = getattr(config, "rms_norm_eps", 1e-6)
            self.visual = EXAONE4_5_VisionTransformer(
                config.vision_config,
                norm_eps=vision_eps,
                quant_config=self.quant_config,
                prefix=maybe_prefix(prefix, "visual"),
                use_data_parallel=self.use_data_parallel,
            )

        # Text backbone, instantiated from the text sub-config.
        with self._mark_language_model(vllm_config):
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                prefix=maybe_prefix(prefix, "language_model"),
                hf_config=config.get_text_config(),
                architectures=["Exaone4ForCausalLM"],
            )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, skipping every name that starts with ``mtp.``.

        Returns:
            The set returned by the weights loader.
        """
        weights_loader = AutoWeightsLoader(self, skip_prefixes=["mtp."])
        return weights_loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for an image or video item.

        Raises:
            ValueError: If ``modality`` starts with neither ``"image"`` nor
                ``"video"``.
        """
        placeholder_by_kind = (
            ("image", "<vision><|image_pad|></vision>"),
            ("video", "<vision><|video_pad|></vision>"),
        )
        for kind, placeholder in placeholder_by_kind:
            if modality.startswith(kind):
                return placeholder

        raise ValueError("Only image or video modality is supported")
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Inference-only EXAONE-4_5 MTP model."""
from collections.abc import Iterable
import torch
from torch import nn
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.models.exaone4 import Exaone4DecoderLayer
from vllm.model_executor.models.exaone_moe_mtp import (
ExaoneMoeMTP,
ExaoneMoeMultiTokenPredictor,
)
from .utils import (
AutoWeightsLoader,
maybe_prefix,
)
# Module-level logger, created from this module's ``__name__``.
logger = init_logger(__name__)
# Alias for a pair of tensors; not referenced in the visible part of this module.
KVCache = tuple[torch.Tensor, torch.Tensor]
@support_torch_compile
class Exaone4_5MultiTokenPredictor(ExaoneMoeMultiTokenPredictor):
    """Multi-token predictor for EXAONE-4.5 built from Exaone4 decoder layers."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Create the embedding, fusion projection, decoder layers and norms.

        The parent ``__init__`` is bypassed; only ``nn.Module.__init__`` runs
        before the sub-modules are built here.
        """
        nn.Module.__init__(self)

        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        config = vllm_config.model_config.hf_config
        self.config = config

        # Extra vocabulary rows reserved for LoRA adapters, if LoRA is on.
        if lora_config:
            lora_vocab = lora_config.lora_extra_vocab_size * (
                lora_config.max_loras or 1
            )
        else:
            lora_vocab = 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size

        self.mtp_start_layer_idx = config.num_hidden_layers
        self.num_mtp_layers = getattr(config, "num_nextn_predict_layers", 1)

        hidden_size = config.hidden_size
        norm_eps = config.rms_norm_eps

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            hidden_size,
            org_num_embeddings=config.vocab_size,
        )

        # Maps a 2 * hidden_size input back down to hidden_size.
        self.fc = ColumnParallelLinear(
            hidden_size * 2,
            hidden_size,
            gather_output=True,
            bias=False,
            return_bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.fc",
        )

        self.layers = nn.ModuleList(
            [
                Exaone4DecoderLayer(
                    config,
                    quant_config=quant_config,
                    prefix=f"{prefix}.layers.{idx}",
                )
                for idx in range(self.num_mtp_layers)
            ]
        )

        self.norm = RMSNorm(hidden_size, eps=norm_eps)
        self.pre_fc_norm_hidden = RMSNorm(hidden_size, eps=norm_eps)
        self.pre_fc_norm_embedding = RMSNorm(hidden_size, eps=norm_eps)
@support_torch_compile
class Exaone4_5_MTP(ExaoneMoeMTP):
    """EXAONE-4.5 MTP draft model: a multi-token predictor plus an LM head."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the predictor, the LM head and the logits processor.

        The parent ``__init__`` is bypassed; only ``nn.Module.__init__`` runs.
        """
        # Initialise the module machinery before assigning any attribute.
        nn.Module.__init__(self)
        config = vllm_config.model_config.hf_config
        self.vllm_config = vllm_config
        self.quant_config = vllm_config.quant_config
        self.config = config
        self.model = Exaone4_5MultiTokenPredictor(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "mtp")
        )
        self.unpadded_vocab_size = config.vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if config.tie_word_embeddings:
            # Share one weight tensor between the LM head and the embedding.
            self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(
            self.unpadded_vocab_size, config.vocab_size
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load only the MTP weights and the shared embedding / LM head.

        Name handling:
            * ``mtp.<rest>`` becomes ``model.<rest>``.
            * Names containing ``embed_tokens`` have ``language_model.``
              removed.
            * Names containing ``lm_head`` are passed through unchanged.
            * Every other weight is skipped.

        Returns:
            The set returned by the weights loader.
        """

        def remap_weight_names(weights):
            for name, weight in weights:
                if name.startswith("mtp."):
                    # Rewrite the leading prefix only; an unbounded replace
                    # would also alter any later "mtp." inside the name.
                    name = name.replace("mtp.", "model.", 1)
                elif "embed_tokens" in name:
                    name = name.replace("language_model.", "")
                elif "lm_head" not in name:
                    # Neither an MTP weight nor a shared one.
                    continue
                yield name, weight

        loader = AutoWeightsLoader(self)
        return loader.load_weights(remap_weight_names(weights))
......@@ -184,11 +184,6 @@ class ExaoneMoeMTP(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
self.vllm_config = vllm_config
cache_config = vllm_config.cache_config
assert not cache_config.enable_prefix_caching, (
"ExaoneMoeMTP currently does not support prefix caching"
)
self.quant_config = vllm_config.quant_config
super().__init__()
......
......@@ -371,6 +371,10 @@ _MULTIMODAL_MODELS = {
"ernie45_vl",
"Ernie4_5_VLMoeForConditionalGeneration",
),
"Exaone4_5_ForConditionalGeneration": (
"exaone4_5",
"Exaone4_5_ForConditionalGeneration",
), # noqa: E501
"FireRedASR2ForConditionalGeneration": (
"fireredasr2",
"FireRedASR2ForConditionalGeneration",
......@@ -569,6 +573,7 @@ _SPECULATIVE_DECODING_MODELS = {
"DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
"ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
"ExaoneMoeMTP": ("exaone_moe_mtp", "ExaoneMoeMTP"),
"Exaone4_5_MTP": ("exaone4_5_mtp", "Exaone4_5_MTP"),
"NemotronHMTPModel": ("nemotron_h_mtp", "NemotronHMTP"),
"LongCatFlashMTPModel": ("longcat_flash_mtp", "LongCatFlashMTP"),
"Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
......
......@@ -1321,13 +1321,14 @@ class SpecDecodeBaseProposer:
# handle multimodality
assert hasattr(target_model, "config")
if self.get_model_name(target_model) in [
"Qwen2_5_VLForConditionalGeneration",
"Qwen3VLForConditionalGeneration",
"Qwen3VLMoeForConditionalGeneration",
"HunYuanVLForConditionalGeneration",
"Exaone4_5_ForConditionalGeneration",
"GlmOcrForConditionalGeneration",
"HunYuanVLForConditionalGeneration",
"Qwen2_5_VLForConditionalGeneration",
"Qwen3_5ForConditionalGeneration",
"Qwen3_5MoeForConditionalGeneration",
"Qwen3VLForConditionalGeneration",
"Qwen3VLMoeForConditionalGeneration",
]:
self.model.config.image_token_index = target_model.config.image_token_id
elif self.get_model_name(target_model) == "PixtralForConditionalGeneration":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment