Unverified Commit 63ed2409 authored by Kyungmin Lee's avatar Kyungmin Lee Committed by GitHub
Browse files

Add K-EXAONE-236B-A23B (#31621)


Signed-off-by: default avatarlkm2835 <lkm2835@gmail.com>
Signed-off-by: default avatarCyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: default avatarlgai-exaone <exaonemodels@lgresearch.ai>
Co-authored-by: default avatarCyrus Leung <cyrus.tl.leung@gmail.com>
parent 95e53d90
...@@ -375,6 +375,7 @@ th { ...@@ -375,6 +375,7 @@ th {
| `Ernie4_5ForCausalLM` | Ernie4.5 | `baidu/ERNIE-4.5-0.3B-PT`, etc. | ✅︎ | ✅︎ | | `Ernie4_5ForCausalLM` | Ernie4.5 | `baidu/ERNIE-4.5-0.3B-PT`, etc. | ✅︎ | ✅︎ |
| `Ernie4_5_MoeForCausalLM` | Ernie4.5MoE | `baidu/ERNIE-4.5-21B-A3B-PT`, `baidu/ERNIE-4.5-300B-A47B-PT`, etc. |✅︎| ✅︎ | | `Ernie4_5_MoeForCausalLM` | Ernie4.5MoE | `baidu/ERNIE-4.5-21B-A3B-PT`, `baidu/ERNIE-4.5-300B-A47B-PT`, etc. |✅︎| ✅︎ |
| `ExaoneForCausalLM` | EXAONE-3 | `LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct`, etc. | ✅︎ | ✅︎ | | `ExaoneForCausalLM` | EXAONE-3 | `LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct`, etc. | ✅︎ | ✅︎ |
| `ExaoneMoeCausalLM` | K-EXAONE | `LGAI-EXAONE/K-EXAONE-236B-A23B`, etc. | | |
| `Exaone4ForCausalLM` | EXAONE-4 | `LGAI-EXAONE/EXAONE-4.0-32B`, etc. | ✅︎ | ✅︎ | | `Exaone4ForCausalLM` | EXAONE-4 | `LGAI-EXAONE/EXAONE-4.0-32B`, etc. | ✅︎ | ✅︎ |
| `Fairseq2LlamaForCausalLM` | Llama (fairseq2 format) | `mgleize/fairseq2-dummy-Llama-3.2-1B`, etc. | ✅︎ | ✅︎ | | `Fairseq2LlamaForCausalLM` | Llama (fairseq2 format) | `mgleize/fairseq2-dummy-Llama-3.2-1B`, etc. | ✅︎ | ✅︎ |
| `FalconForCausalLM` | Falcon | `tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc. | | ✅︎ | | `FalconForCausalLM` | Falcon | `tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc. | | ✅︎ |
......
...@@ -250,6 +250,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { ...@@ -250,6 +250,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct", trust_remote_code=True "LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct", trust_remote_code=True
), ),
"Exaone4ForCausalLM": _HfExamplesInfo("LGAI-EXAONE/EXAONE-4.0-32B"), "Exaone4ForCausalLM": _HfExamplesInfo("LGAI-EXAONE/EXAONE-4.0-32B"),
"ExaoneMoEForCausalLM": _HfExamplesInfo(
"LGAI-EXAONE/K-EXAONE-236B-A23B", min_transformers_version="5.0.0"
),
"Fairseq2LlamaForCausalLM": _HfExamplesInfo("mgleize/fairseq2-dummy-Llama-3.2-1B"), "Fairseq2LlamaForCausalLM": _HfExamplesInfo("mgleize/fairseq2-dummy-Llama-3.2-1B"),
"FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"), "FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"),
"FalconH1ForCausalLM": _HfExamplesInfo("tiiuae/Falcon-H1-0.5B-Base"), "FalconH1ForCausalLM": _HfExamplesInfo("tiiuae/Falcon-H1-0.5B-Base"),
...@@ -1005,6 +1008,11 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = { ...@@ -1005,6 +1008,11 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
trust_remote_code=True, trust_remote_code=True,
speculative_model="baidu/ERNIE-4.5-21B-A3B-PT", speculative_model="baidu/ERNIE-4.5-21B-A3B-PT",
), ),
"ExaoneMoeMTP": _HfExamplesInfo(
"LGAI-EXAONE/K-EXAONE-236B-A23B",
speculative_model="LGAI-EXAONE/K-EXAONE-236B-A23B",
min_transformers_version="5.0.0",
),
"Glm4MoeMTPModel": _HfExamplesInfo( "Glm4MoeMTPModel": _HfExamplesInfo(
"zai-org/GLM-4.5", "zai-org/GLM-4.5",
speculative_model="zai-org/GLM-4.5", speculative_model="zai-org/GLM-4.5",
......
...@@ -33,6 +33,7 @@ MTPModelTypes = Literal[ ...@@ -33,6 +33,7 @@ MTPModelTypes = Literal[
"mimo_mtp", "mimo_mtp",
"glm4_moe_mtp", "glm4_moe_mtp",
"ernie_mtp", "ernie_mtp",
"exaone_moe_mtp",
"qwen3_next_mtp", "qwen3_next_mtp",
"longcat_flash_mtp", "longcat_flash_mtp",
"mtp", "mtp",
...@@ -219,6 +220,15 @@ class SpeculativeConfig: ...@@ -219,6 +220,15 @@ class SpeculativeConfig:
hf_config.update( hf_config.update(
{"n_predict": n_predict, "architectures": ["Qwen3NextMTP"]} {"n_predict": n_predict, "architectures": ["Qwen3NextMTP"]}
) )
if hf_config.model_type == "exaone_moe":
hf_config.model_type = "exaone_moe_mtp"
if hf_config.model_type == "exaone_moe_mtp":
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
hf_config.update(
{"n_predict": n_predict, "architectures": ["ExaoneMoeMTP"]}
)
if hf_config.model_type == "longcat_flash": if hf_config.model_type == "longcat_flash":
hf_config.model_type = "longcat_flash_mtp" hf_config.model_type = "longcat_flash_mtp"
n_predict = getattr(hf_config, "num_nextn_predict_layers", 1) n_predict = getattr(hf_config, "num_nextn_predict_layers", 1)
......
...@@ -72,6 +72,7 @@ class Exaone4GatedMLP(nn.Module): ...@@ -72,6 +72,7 @@ class Exaone4GatedMLP(nn.Module):
intermediate_size: int, intermediate_size: int,
hidden_act: str, hidden_act: str,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
reduce_results: bool = True,
bias: bool = False, bias: bool = False,
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
...@@ -88,6 +89,7 @@ class Exaone4GatedMLP(nn.Module): ...@@ -88,6 +89,7 @@ class Exaone4GatedMLP(nn.Module):
output_size=hidden_size, output_size=hidden_size,
bias=bias, bias=bias,
quant_config=quant_config, quant_config=quant_config,
reduce_results=reduce_results,
prefix=f"{prefix}.down_proj", prefix=f"{prefix}.down_proj",
) )
if hidden_act != "silu": if hidden_act != "silu":
......
This diff is collapsed.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Inference-only ExaoneMoe MTP model."""
from collections.abc import Iterable
import torch
from torch import nn
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.exaone_moe import ExaoneMoeDecoderLayer
from vllm.sequence import IntermediateTensors
from .utils import (
AutoWeightsLoader,
is_pp_missing_parameter,
maybe_prefix,
)
logger = init_logger(__name__)
KVCache = tuple[torch.Tensor, torch.Tensor]
@support_torch_compile
class ExaoneMoeMultiTokenPredictor(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
model_config = vllm_config.model_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
config = model_config.hf_config
self.config = config
lora_vocab = (
(lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
if lora_config
else 0
)
self.vocab_size = config.vocab_size + lora_vocab
self.org_vocab_size = config.vocab_size
self.mtp_start_layer_idx = config.num_hidden_layers
self.num_mtp_layers = getattr(config, "num_nextn_predict_layers", 1)
self.embed_tokens = VocabParallelEmbedding(
self.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
)
self.fc = ColumnParallelLinear(
self.config.hidden_size * 2,
self.config.hidden_size,
gather_output=True,
bias=False,
return_bias=False,
quant_config=quant_config,
prefix=f"{prefix}.fc",
)
self.layers = nn.ModuleList(
ExaoneMoeDecoderLayer(
vllm_config.model_config.hf_config,
quant_config=quant_config,
prefix=f"{prefix}.layers.{idx}",
mtp_layer=True,
)
for idx in range(self.num_mtp_layers)
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.pre_fc_norm_hidden = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.pre_fc_norm_embedding = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
spec_step_idx: int = 0,
) -> torch.Tensor:
if get_pp_group().is_first_rank:
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings(input_ids)
assert hidden_states.shape[-1] == inputs_embeds.shape[-1]
inputs_embeds = self.pre_fc_norm_embedding(inputs_embeds)
hidden_states = self.pre_fc_norm_hidden(hidden_states)
hidden_states = torch.cat([inputs_embeds, hidden_states], dim=-1)
hidden_states = self.fc(hidden_states)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
current_step_idx = spec_step_idx % self.num_mtp_layers
hidden_states, residual = self.layers[current_step_idx](
positions=positions,
hidden_states=hidden_states,
residual=residual,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors(
{"hidden_states": hidden_states, "residual": residual}
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
if "mlp.experts" in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
@support_torch_compile
class ExaoneMoeMTP(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
self.vllm_config = vllm_config
cache_config = vllm_config.cache_config
assert not cache_config.enable_prefix_caching, (
"ExaoneMoeMTP currently does not support prefix caching"
)
self.quant_config = vllm_config.quant_config
super().__init__()
self.config = config
self.model = ExaoneMoeMultiTokenPredictor(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "mtp")
)
self.unpadded_vocab_size = config.vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
# padding_size=DEFAULT_VOCAB_PADDING_SIZE,
prefix=maybe_prefix(prefix, "lm_head"),
)
if config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(
self.unpadded_vocab_size, config.vocab_size
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
spec_step_idx: int = 0,
**kwargs: object,
):
hidden_states = self.model(
input_ids,
positions,
hidden_states,
intermediate_tensors,
inputs_embeds,
spec_step_idx,
)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
spec_step_idx: int = 0,
) -> torch.Tensor | None:
return self.logits_processor(self.lm_head, hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
shared_weight_names = ["embed_tokens", "lm_head"]
def remap_weight_names(weights):
for name, weight in weights:
if name.startswith("mtp."):
name = name.replace("mtp.", "model.")
elif not any(key in name for key in shared_weight_names):
continue
yield name, weight
loader = AutoWeightsLoader(self)
return loader.load_weights(remap_weight_names(weights))
...@@ -98,6 +98,7 @@ _TEXT_GENERATION_MODELS = { ...@@ -98,6 +98,7 @@ _TEXT_GENERATION_MODELS = {
"Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"), "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
"ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"), "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
"Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"), "Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"),
"ExaoneMoEForCausalLM": ("exaone_moe", "ExaoneMoeForCausalLM"),
"Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"), "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
"FalconForCausalLM": ("falcon", "FalconForCausalLM"), "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
"FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"), "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
...@@ -457,6 +458,7 @@ _SPECULATIVE_DECODING_MODELS = { ...@@ -457,6 +458,7 @@ _SPECULATIVE_DECODING_MODELS = {
"EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"), "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
"DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"), "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
"ErnieMTPModel": ("ernie_mtp", "ErnieMTP"), "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
"ExaoneMoeMTP": ("exaone_moe_mtp", "ExaoneMoeMTP"),
"LongCatFlashMTPModel": ("longcat_flash_mtp", "LongCatFlashMTP"), "LongCatFlashMTPModel": ("longcat_flash_mtp", "LongCatFlashMTP"),
"Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"), "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
"MedusaModel": ("medusa", "Medusa"), "MedusaModel": ("medusa", "Medusa"),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment