Unverified Commit da72dace authored by Lee Yongjun's avatar Lee Yongjun Committed by GitHub
Browse files

[Bugfix] add SupportsMultiModal to Exaone4_5_MTP (#39526)


Signed-off-by: default avatarleeyongjun <jqueen.astro@gmail.com>
parent 8d0aabdd
...@@ -23,8 +23,14 @@ from vllm.model_executor.models.exaone_moe_mtp import ( ...@@ -23,8 +23,14 @@ from vllm.model_executor.models.exaone_moe_mtp import (
ExaoneMoeMultiTokenPredictor, ExaoneMoeMultiTokenPredictor,
) )
from .interfaces import (
MultiModalEmbeddings,
SupportsMultiModal,
_require_is_multimodal,
)
from .utils import ( from .utils import (
AutoWeightsLoader, AutoWeightsLoader,
_merge_multimodal_embeddings,
maybe_prefix, maybe_prefix,
) )
...@@ -85,9 +91,12 @@ class Exaone4_5MultiTokenPredictor(ExaoneMoeMultiTokenPredictor): ...@@ -85,9 +91,12 @@ class Exaone4_5MultiTokenPredictor(ExaoneMoeMultiTokenPredictor):
config.hidden_size, eps=config.rms_norm_eps config.hidden_size, eps=config.rms_norm_eps
) )
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
@support_torch_compile @support_torch_compile
class Exaone4_5_MTP(ExaoneMoeMTP): class Exaone4_5_MTP(ExaoneMoeMTP, SupportsMultiModal):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
self.vllm_config = vllm_config self.vllm_config = vllm_config
...@@ -111,6 +120,32 @@ class Exaone4_5_MTP(ExaoneMoeMTP): ...@@ -111,6 +120,32 @@ class Exaone4_5_MTP(ExaoneMoeMTP):
self.unpadded_vocab_size, config.vocab_size self.unpadded_vocab_size, config.vocab_size
) )
def embed_input_ids(
self,
input_ids: torch.Tensor,
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
) -> torch.Tensor:
inputs_embeds = self._embed_text_input_ids(
input_ids,
self.model.embed_input_ids,
is_multimodal=is_multimodal,
)
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
return inputs_embeds
is_multimodal = _require_is_multimodal(is_multimodal)
inputs_embeds = _merge_multimodal_embeddings(
inputs_embeds=inputs_embeds,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
)
return inputs_embeds
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
shared_weight_names = ["embed_tokens", "lm_head"] shared_weight_names = ["embed_tokens", "lm_head"]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment