Unverified commit ec17bdd8, authored by Cyrus Leung, committed by GitHub
Browse files

[Renderer] Move InputPreprocessor into Renderer (1.5/2) (#34598)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent bb59c902
...@@ -93,14 +93,14 @@ def _build_renderer( ...@@ -93,14 +93,14 @@ def _build_renderer(
def _preprocess_prompt(
    model_config: ModelConfig,
    prompt_or_prompts: SingletonPrompt | bytes | Sequence[SingletonPrompt | bytes],
):
    """Parse one prompt, or a sequence of prompts, into a list.

    Args:
        model_config: Passed through to `parse_model_prompt` for every
            non-`bytes` prompt.
        prompt_or_prompts: A single prompt or a sequence of prompts; it is
            iterated via `prompt_to_seq`.

    Returns:
        A list with one entry per prompt yielded by `prompt_to_seq`: `bytes`
        entries are kept as-is, all other entries are the result of
        `parse_model_prompt(model_config, prompt)`.
    """
    return [
        (
            # Raw bytes are forwarded untouched; only non-bytes prompts are parsed.
            prompt
            if isinstance(prompt, bytes)
            else parse_model_prompt(model_config, prompt)
        )
        for prompt in prompt_to_seq(prompt_or_prompts)
    ]
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
from typing import TYPE_CHECKING, Any, Literal, TypeAlias from typing import TYPE_CHECKING, Any, Literal, TypeAlias
import torch import torch
from typing_extensions import NotRequired, TypedDict from typing_extensions import NotRequired, TypedDict, assert_never
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.multimodal.inputs import ( from vllm.multimodal.inputs import (
...@@ -200,15 +200,22 @@ class TokenInputs(_InputOptions): ...@@ -200,15 +200,22 @@ class TokenInputs(_InputOptions):
prompt_token_ids: list[int] prompt_token_ids: list[int]
"""The token IDs of the prompt.""" """The token IDs of the prompt."""
prompt: NotRequired[str]
"""The prompt text corresponding to the token IDs, if available."""
def token_inputs( def token_inputs(
prompt_token_ids: list[int], prompt_token_ids: list[int],
*,
prompt: str | None = None,
cache_salt: str | None = None, cache_salt: str | None = None,
) -> TokenInputs: ) -> TokenInputs:
"""Construct [`TokenInputs`][vllm.inputs.data.TokenInputs] from optional """Construct [`TokenInputs`][vllm.inputs.data.TokenInputs] from optional
values.""" values."""
inputs = TokenInputs(type="token", prompt_token_ids=prompt_token_ids) inputs = TokenInputs(type="token", prompt_token_ids=prompt_token_ids)
if prompt is not None:
inputs["prompt"] = prompt
if cache_salt is not None: if cache_salt is not None:
inputs["cache_salt"] = cache_salt inputs["cache_salt"] = cache_salt
...@@ -224,15 +231,22 @@ class EmbedsInputs(_InputOptions): ...@@ -224,15 +231,22 @@ class EmbedsInputs(_InputOptions):
prompt_embeds: torch.Tensor prompt_embeds: torch.Tensor
"""The embeddings of the prompt.""" """The embeddings of the prompt."""
prompt: NotRequired[str]
"""The prompt text corresponding to the embeddings, if available."""
def embeds_inputs( def embeds_inputs(
prompt_embeds: torch.Tensor, prompt_embeds: torch.Tensor,
*,
prompt: str | None = None,
cache_salt: str | None = None, cache_salt: str | None = None,
) -> EmbedsInputs: ) -> EmbedsInputs:
"""Construct [`EmbedsInputs`][vllm.inputs.data.EmbedsInputs] from optional """Construct [`EmbedsInputs`][vllm.inputs.data.EmbedsInputs] from optional
values.""" values."""
inputs = EmbedsInputs(type="embeds", prompt_embeds=prompt_embeds) inputs = EmbedsInputs(type="embeds", prompt_embeds=prompt_embeds)
if prompt is not None:
inputs["prompt"] = prompt
if cache_salt is not None: if cache_salt is not None:
inputs["cache_salt"] = cache_salt inputs["cache_salt"] = cache_salt
...@@ -278,10 +292,12 @@ class EncoderDecoderInputs(TypedDict): ...@@ -278,10 +292,12 @@ class EncoderDecoderInputs(TypedDict):
for encoder-decoder models. for encoder-decoder models.
""" """
encoder: EncoderInputs type: Literal["enc_dec"]
encoder_prompt: EncoderInputs
"""The inputs for the encoder portion.""" """The inputs for the encoder portion."""
decoder: DecoderInputs decoder_prompt: DecoderInputs
"""The inputs for the decoder portion.""" """The inputs for the decoder portion."""
...@@ -296,3 +312,94 @@ which can be passed to ...@@ -296,3 +312,94 @@ which can be passed to
SingletonInputs: TypeAlias = DecoderOnlyInputs | MultiModalEncDecInputs SingletonInputs: TypeAlias = DecoderOnlyInputs | MultiModalEncDecInputs
"""The inputs for a single encoder/decoder prompt.""" """The inputs for a single encoder/decoder prompt."""
def _validate_enc_inputs(inputs: SingletonInputs) -> EncoderInputs:
    """Check that `inputs` may be used as the encoder side of an enc-dec model.

    Args:
        inputs: The processed inputs to validate; only its "type" key and the
            presence of "encoder_prompt_token_ids" are inspected.

    Returns:
        The very same `inputs` object, unchanged.

    Raises:
        ValueError: If `inputs["type"]` is "embeds".
        RuntimeError: If `inputs["type"]` is "multimodal" but the
            "encoder_prompt_token_ids" key is absent.
    """
    if inputs["type"] == "embeds":
        raise ValueError(
            "Embedding inputs are not supported for encoder-decoder models"
        )

    if inputs["type"] == "multimodal" and "encoder_prompt_token_ids" not in inputs:
        raise RuntimeError(
            "You should register an encoder-decoder multi-modal processor "
            "for encoder-decoder models."
        )

    # The checks above narrow the union at runtime; the checker cannot see that.
    return inputs  # type: ignore[return-value]
def _validate_dec_inputs(inputs: SingletonInputs) -> DecoderInputs:
    """Check that `inputs` may be used as the decoder side of an enc-dec model.

    Args:
        inputs: The processed inputs to validate; only its "type" key is
            inspected.

    Returns:
        The very same `inputs` object, unchanged.

    Raises:
        ValueError: If `inputs["type"]` is "embeds".
    """
    if inputs["type"] == "embeds":
        raise ValueError(
            "Embedding inputs are not supported for encoder-decoder models"
        )

    return inputs
def _prepare_decoder_input_ids_for_generation(
    decoder_input_ids: list[int],
    decoder_start_token_id: int,
) -> list[int]:
    """
    Prepare `decoder_input_ids` for generation with encoder-decoder models,
    according to `GenerationMixin._prepare_decoder_input_ids_for_generation()`.

    Source:
    https://github.com/huggingface/transformers/blob/v5.1.0/src/transformers/generation/utils.py

    Args:
        decoder_input_ids: The decoder prompt token IDs; never mutated.
        decoder_start_token_id: The token ID the decoder prompt must begin with.

    Returns:
        `decoder_input_ids` itself when it already begins with
        `decoder_start_token_id`; otherwise a new list with that token
        prepended (also for an empty prompt).
    """
    if not decoder_input_ids or decoder_input_ids[0] != decoder_start_token_id:
        # Build a fresh list so the caller's list is left untouched.
        return [decoder_start_token_id, *decoder_input_ids]

    return decoder_input_ids
def build_enc_dec_inputs(
    encoder_inputs: SingletonInputs,
    decoder_inputs: SingletonInputs | None,
    decoder_start_token_id: int,
) -> EncoderDecoderInputs:
    """Combine processed encoder and decoder inputs into `EncoderDecoderInputs`.

    Args:
        encoder_inputs: Processed inputs for the encoder prompt. Must pass
            `_validate_enc_inputs`.
        decoder_inputs: Processed inputs for the decoder prompt, or `None` to
            reuse the (validated) encoder inputs as the decoder inputs.
        decoder_start_token_id: Token ID that the decoder prompt is made to
            start with.

    Returns:
        A dict of type "enc_dec" with "encoder_prompt" and "decoder_prompt".
        For multimodal encoder inputs, the encoder prompt holds the
        "encoder_prompt_token_ids" and the multimodal data moves to the decoder
        prompt. For token encoder inputs, the encoder prompt is empty and the
        decoder prompt is a copy of the decoder inputs.

    Raises:
        ValueError: If either input has type "embeds".
        RuntimeError: If multimodal encoder inputs lack
            "encoder_prompt_token_ids".
    """
    enc_inputs = _validate_enc_inputs(encoder_inputs)

    if decoder_inputs is None:
        dec_inputs: DecoderInputs = enc_inputs
    else:
        dec_inputs = _validate_dec_inputs(decoder_inputs)

    enc_inputs_new: EncoderInputs
    dec_inputs_new: DecoderInputs

    if enc_inputs["type"] == "multimodal":
        # NOTE(review): function-scope import, presumably to avoid an import
        # cycle with vllm.multimodal.inputs -- confirm before hoisting.
        from vllm.multimodal.inputs import mm_inputs

        enc_inputs_new = token_inputs(
            enc_inputs["encoder_prompt_token_ids"],
            prompt=enc_inputs.get("encoder_prompt"),
        )
        dec_inputs_new = mm_inputs(
            prompt_token_ids=dec_inputs["prompt_token_ids"],
            prompt=dec_inputs.get("prompt"),
            mm_kwargs=enc_inputs["mm_kwargs"],
            mm_hashes=enc_inputs["mm_hashes"],
            mm_placeholders=enc_inputs["mm_placeholders"],
        )
    elif enc_inputs["type"] == "token":
        enc_inputs_new = token_inputs(prompt_token_ids=[])
        # Shallow-copy so the writes below do not leak into the dict the
        # caller passed in (which may also be `encoder_inputs` itself).
        dec_inputs_new = dec_inputs.copy()
    else:
        assert_never(enc_inputs)

    dec_inputs_new["prompt_token_ids"] = _prepare_decoder_input_ids_for_generation(
        dec_inputs_new["prompt_token_ids"],
        decoder_start_token_id,
    )

    # The cache salt given with the encoder inputs is carried on the decoder
    # prompt; a missing or empty salt is skipped.
    if cache_salt := enc_inputs.get("cache_salt"):
        dec_inputs_new["cache_salt"] = cache_salt

    return EncoderDecoderInputs(
        type="enc_dec",
        encoder_prompt=enc_inputs_new,
        decoder_prompt=dec_inputs_new,
    )
...@@ -7,11 +7,7 @@ from .data import ProcessorInputs, SingletonInputs ...@@ -7,11 +7,7 @@ from .data import ProcessorInputs, SingletonInputs
def split_enc_dec_inputs(
    inputs: ProcessorInputs,
) -> tuple[SingletonInputs | None, SingletonInputs]:
    """Split processor inputs into an `(encoder, decoder)` pair.

    Args:
        inputs: Processed inputs; dispatch is on its "type" key.

    Returns:
        `(inputs["encoder_prompt"], inputs["decoder_prompt"])` when the type is
        "enc_dec"; otherwise `(None, inputs)`, i.e. there is no encoder part
        and `inputs` itself is the decoder part.
    """
    if inputs["type"] == "enc_dec":
        return inputs["encoder_prompt"], inputs["decoder_prompt"]

    return None, inputs
...@@ -7,6 +7,7 @@ from typing import Any, overload ...@@ -7,6 +7,7 @@ from typing import Any, overload
from typing_extensions import assert_never from typing_extensions import assert_never
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.inputs.data import build_enc_dec_inputs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.inputs import ( from vllm.multimodal.inputs import (
...@@ -67,54 +68,6 @@ class InputPreprocessor: ...@@ -67,54 +68,6 @@ class InputPreprocessor:
def get_tokenizer(self) -> TokenizerLike: def get_tokenizer(self) -> TokenizerLike:
return self.renderer.get_tokenizer() return self.renderer.get_tokenizer()
def get_decoder_start_token_id(self) -> int:
    """
    Obtain the decoder start token id employed by an encoder/decoder
    model. Raises an error if it is not available.

    The value is read from `decoder_start_token_id` on
    `self.model_config.hf_config`. When that attribute is missing or `None`,
    a warning is logged and `self.renderer.get_bos_token_id()` is used instead.

    Raises:
        RuntimeError: If neither the decoder start token id nor the BOS token
            id is available.
    """
    dec_start_token_id = getattr(
        self.model_config.hf_config, "decoder_start_token_id", None
    )
    # Compare against None explicitly: a token id of 0 is a valid value.
    if dec_start_token_id is not None:
        return dec_start_token_id

    logger.warning_once(
        "Falling back on <BOS> for decoder start token id "
        "because decoder start token id is not available."
    )
    dec_start_token_id = self.renderer.get_bos_token_id()

    if dec_start_token_id is None:
        raise RuntimeError("Cannot find decoder start token id or <BOS>")

    return dec_start_token_id
def _prepare_decoder_input_ids(self, decoder_input_ids: list[int]) -> list[int]:
    """
    Prepares `decoder_input_ids` for generation with encoder-decoder models.

    Based on:
    https://github.com/huggingface/transformers/blob/4037a2b5b1278736e566aec12e169100275545ea/src/transformers/generation/utils.py
    specifically,
    `GenerationMixin._prepare_decoder_input_ids_for_generation()`.

    Arguments:

    * decoder_input_ids: input token ids to preprocess (never mutated)

    Returns:

    * Processed token list: the input itself when it already begins with the
      decoder start token, otherwise a new list with that token prepended
    """
    decoder_start_token_id = self.get_decoder_start_token_id()

    if not decoder_input_ids or decoder_input_ids[0] != decoder_start_token_id:
        # Build a fresh list so the caller's list is left untouched.
        return [decoder_start_token_id, *decoder_input_ids]

    return decoder_input_ids
def _tokenize_prompt( def _tokenize_prompt(
self, self,
prompt: str, prompt: str,
...@@ -332,66 +285,6 @@ class InputPreprocessor: ...@@ -332,66 +285,6 @@ class InputPreprocessor:
assert_never(prompt) # type: ignore[arg-type] assert_never(prompt) # type: ignore[arg-type]
def _validate_enc_inputs(self, inputs: SingletonInputs) -> EncoderInputs:
    """Check that `inputs` may be used as the encoder side of an enc-dec model.

    Returns the very same `inputs` object, unchanged.

    Raises:
        ValueError: If `inputs["type"]` is "embeds".
        RuntimeError: If `inputs["type"]` is "multimodal" but the
            "encoder_prompt_token_ids" key is absent.
    """
    if inputs["type"] == "embeds":
        raise ValueError(
            "Embedding inputs are not supported for encoder-decoder models"
        )

    if inputs["type"] == "multimodal" and "encoder_prompt_token_ids" not in inputs:
        raise RuntimeError(
            "You should register an encoder-decoder "
            "multi-modal processor for encoder-decoder models."
        )

    # The checks above narrow the union at runtime; the checker cannot see that.
    return inputs  # type: ignore[return-value]
def _validate_dec_inputs(self, inputs: SingletonInputs) -> DecoderInputs:
    """Check that `inputs` may be used as the decoder side of an enc-dec model.

    Returns the very same `inputs` object, unchanged.

    Raises:
        ValueError: If `inputs["type"]` is "embeds".
    """
    if inputs["type"] == "embeds":
        raise ValueError(
            "Embedding inputs are not supported for encoder-decoder models"
        )

    return inputs
def _build_enc_dec_inputs(
    self,
    encoder_inputs: SingletonInputs,
    decoder_inputs: SingletonInputs | None = None,
) -> EncoderDecoderInputs:
    """Combine processed encoder and decoder inputs into `EncoderDecoderInputs`.

    Args:
        encoder_inputs: Processed inputs for the encoder prompt. Must pass
            `self._validate_enc_inputs`.
        decoder_inputs: Processed inputs for the decoder prompt, or `None` to
            reuse the (validated) encoder inputs as the decoder inputs.

    Returns:
        A dict with "encoder" and "decoder" entries. For multimodal encoder
        inputs, the encoder entry holds the "encoder_prompt_token_ids" and the
        multimodal data moves to the decoder entry. For token encoder inputs,
        the encoder entry is empty and the decoder entry is a copy of the
        decoder inputs.

    Raises:
        ValueError: If either input has type "embeds".
        RuntimeError: If multimodal encoder inputs lack
            "encoder_prompt_token_ids".
    """
    enc_inputs = self._validate_enc_inputs(encoder_inputs)

    if decoder_inputs is None:
        dec_inputs: DecoderInputs = enc_inputs  # type: ignore[assignment]
    else:
        dec_inputs = self._validate_dec_inputs(decoder_inputs)

    enc_inputs_new: EncoderInputs
    dec_inputs_new: DecoderInputs

    if enc_inputs["type"] == "multimodal":
        enc_inputs_new = token_inputs(enc_inputs["encoder_prompt_token_ids"])
        dec_inputs_new = MultiModalInputs(
            type="multimodal",
            prompt_token_ids=dec_inputs["prompt_token_ids"],
            mm_kwargs=enc_inputs["mm_kwargs"],
            mm_hashes=enc_inputs["mm_hashes"],
            mm_placeholders=enc_inputs["mm_placeholders"],
        )
    elif enc_inputs["type"] == "token":
        enc_inputs_new = token_inputs(prompt_token_ids=[])
        # Shallow-copy so the writes below do not leak into the dict the
        # caller passed in (which may also be `encoder_inputs` itself).
        dec_inputs_new = dec_inputs.copy()
    else:
        assert_never(enc_inputs)

    dec_inputs_new["prompt_token_ids"] = self._prepare_decoder_input_ids(
        dec_inputs_new["prompt_token_ids"]
    )

    # The cache salt given with the encoder inputs is carried on the decoder
    # entry; a missing or empty salt is skipped.
    if cache_salt := enc_inputs.get("cache_salt"):
        dec_inputs_new["cache_salt"] = cache_salt

    return EncoderDecoderInputs(encoder=enc_inputs_new, decoder=dec_inputs_new)
def _process_encoder_decoder_prompt( def _process_encoder_decoder_prompt(
self, self,
prompt: EncoderDecoderDictPrompt, prompt: EncoderDecoderDictPrompt,
...@@ -417,7 +310,7 @@ class InputPreprocessor: ...@@ -417,7 +310,7 @@ class InputPreprocessor:
encoder_prompt = prompt["encoder_prompt"] encoder_prompt = prompt["encoder_prompt"]
decoder_prompt = prompt["decoder_prompt"] decoder_prompt = prompt["decoder_prompt"]
return self._build_enc_dec_inputs( return build_enc_dec_inputs(
encoder_inputs=self._prompt_to_llm_inputs( encoder_inputs=self._prompt_to_llm_inputs(
encoder_prompt, encoder_prompt,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
...@@ -431,6 +324,7 @@ class InputPreprocessor: ...@@ -431,6 +324,7 @@ class InputPreprocessor:
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
) )
), ),
decoder_start_token_id=self.renderer.get_dec_start_token_id(),
) )
def _process_decoder_only_prompt( def _process_decoder_only_prompt(
......
...@@ -31,6 +31,7 @@ from vllm.multimodal.inputs import ( ...@@ -31,6 +31,7 @@ from vllm.multimodal.inputs import (
MultiModalInputs, MultiModalInputs,
MultiModalKwargsItems, MultiModalKwargsItems,
MultiModalUUIDDict, MultiModalUUIDDict,
mm_inputs,
) )
from vllm.multimodal.parse import ( from vllm.multimodal.parse import (
ImageEmbeddingItems, ImageEmbeddingItems,
...@@ -837,8 +838,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor): ...@@ -837,8 +838,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
for modality, placeholders in mm_placeholders.items() for modality, placeholders in mm_placeholders.items()
} }
return MultiModalInputs( return mm_inputs(
type="multimodal",
prompt_token_ids=prompt_ids, prompt_token_ids=prompt_ids,
mm_kwargs=mm_kwargs, mm_kwargs=mm_kwargs,
mm_hashes=mm_hashes, mm_hashes=mm_hashes,
......
...@@ -48,6 +48,7 @@ from vllm.multimodal.inputs import ( ...@@ -48,6 +48,7 @@ from vllm.multimodal.inputs import (
MultiModalKwargsItems, MultiModalKwargsItems,
MultiModalUUIDDict, MultiModalUUIDDict,
PlaceholderRange, PlaceholderRange,
mm_inputs,
) )
from vllm.multimodal.parse import ( from vllm.multimodal.parse import (
DictEmbeddingItems, DictEmbeddingItems,
...@@ -222,8 +223,7 @@ class TerratorchMultiModalProcessor(BaseMultiModalProcessor[TerratorchProcessing ...@@ -222,8 +223,7 @@ class TerratorchMultiModalProcessor(BaseMultiModalProcessor[TerratorchProcessing
), ),
) )
return MultiModalInputs( return mm_inputs(
type="multimodal",
prompt_token_ids=[1], prompt_token_ids=[1],
mm_kwargs=mm_kwargs, mm_kwargs=mm_kwargs,
mm_hashes=mm_hashes, mm_hashes=mm_hashes,
......
...@@ -33,6 +33,7 @@ from vllm.multimodal.inputs import ( ...@@ -33,6 +33,7 @@ from vllm.multimodal.inputs import (
MultiModalInputs, MultiModalInputs,
MultiModalUUIDDict, MultiModalUUIDDict,
PlaceholderRange, PlaceholderRange,
mm_inputs,
) )
from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataItems from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataItems
from vllm.multimodal.processing import ( from vllm.multimodal.processing import (
...@@ -260,8 +261,7 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]): ...@@ -260,8 +261,7 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
mm_items, hf_processor_mm_kwargs, tokenization_kwargs, mm_uuids=mm_uuids mm_items, hf_processor_mm_kwargs, tokenization_kwargs, mm_uuids=mm_uuids
) )
return MultiModalInputs( return mm_inputs(
type="multimodal",
prompt_token_ids=prompt_ids, prompt_token_ids=prompt_ids,
mm_kwargs=mm_kwargs, mm_kwargs=mm_kwargs,
mm_hashes=mm_hashes, mm_hashes=mm_hashes,
......
...@@ -20,7 +20,7 @@ from typing import ( ...@@ -20,7 +20,7 @@ from typing import (
import numpy as np import numpy as np
from PIL.Image import Image from PIL.Image import Image
from typing_extensions import TypeVar from typing_extensions import NotRequired, TypeVar
from vllm.utils.collection_utils import is_list_of from vllm.utils.collection_utils import is_list_of
from vllm.utils.import_utils import LazyLoader from vllm.utils.import_utils import LazyLoader
...@@ -1075,6 +1075,9 @@ class MultiModalInputs(_InputOptions): ...@@ -1075,6 +1075,9 @@ class MultiModalInputs(_InputOptions):
prompt_token_ids: list[int] prompt_token_ids: list[int]
"""The processed token IDs which includes placeholder tokens.""" """The processed token IDs which includes placeholder tokens."""
prompt: NotRequired[str]
"""The prompt text corresponding to the token IDs, if available."""
mm_kwargs: MultiModalKwargsOptionalItems mm_kwargs: MultiModalKwargsOptionalItems
"""Keyword arguments to be directly passed to the model after batching.""" """Keyword arguments to be directly passed to the model after batching."""
...@@ -1088,6 +1091,31 @@ class MultiModalInputs(_InputOptions): ...@@ -1088,6 +1091,31 @@ class MultiModalInputs(_InputOptions):
""" """
def mm_inputs(
    prompt_token_ids: list[int],
    mm_kwargs: MultiModalKwargsOptionalItems,
    mm_hashes: MultiModalHashes,
    mm_placeholders: MultiModalPlaceholderDict,
    *,
    prompt: str | None = None,
    cache_salt: str | None = None,
) -> MultiModalInputs:
    """Construct `MultiModalInputs` from required and optional values.

    Args:
        prompt_token_ids: Stored under "prompt_token_ids".
        mm_kwargs: Stored under "mm_kwargs".
        mm_hashes: Stored under "mm_hashes".
        mm_placeholders: Stored under "mm_placeholders".
        prompt: Optional prompt text; the "prompt" key is only set when this
            is not `None`.
        cache_salt: Optional cache salt; the "cache_salt" key is only set when
            this is not `None`.

    Returns:
        A new dict with `type="multimodal"`.
    """
    inputs = MultiModalInputs(
        type="multimodal",
        prompt_token_ids=prompt_token_ids,
        mm_kwargs=mm_kwargs,
        mm_hashes=mm_hashes,
        mm_placeholders=mm_placeholders,
    )

    # Optional keys are left out entirely rather than stored as None.
    if prompt is not None:
        inputs["prompt"] = prompt
    if cache_salt is not None:
        inputs["cache_salt"] = cache_salt

    return inputs
class MultiModalEncDecInputs(MultiModalInputs): class MultiModalEncDecInputs(MultiModalInputs):
""" """
Represents the outputs of Represents the outputs of
...@@ -1101,3 +1129,31 @@ class MultiModalEncDecInputs(MultiModalInputs): ...@@ -1101,3 +1129,31 @@ class MultiModalEncDecInputs(MultiModalInputs):
encoder_prompt_token_ids: list[int] encoder_prompt_token_ids: list[int]
"""The processed token IDs of the encoder prompt.""" """The processed token IDs of the encoder prompt."""
encoder_prompt: NotRequired[str]
"""The prompt text corresponding to the encoder token IDs, if available."""
def mm_enc_dec_inputs(
    encoder_inputs: MultiModalInputs,
    decoder_prompt_token_ids: list[int],
    *,
    decoder_prompt: str | None = None,
) -> MultiModalEncDecInputs:
    """Construct `MultiModalEncDecInputs` from encoder inputs and a decoder prompt.

    Args:
        encoder_inputs: Multimodal inputs of the encoder prompt. Its
            "prompt_token_ids" become "encoder_prompt_token_ids"; its
            "mm_kwargs", "mm_hashes" and "mm_placeholders" are reused as-is.
        decoder_prompt_token_ids: Stored under "prompt_token_ids".
        decoder_prompt: Optional decoder prompt text; the "prompt" key is only
            set when this is not `None`.

    Returns:
        A new dict with `type="multimodal"`. "encoder_prompt" and "cache_salt"
        are copied over only when `encoder_inputs` has "prompt" and
        "cache_salt" respectively.
    """
    inputs = MultiModalEncDecInputs(
        type="multimodal",
        prompt_token_ids=decoder_prompt_token_ids,
        encoder_prompt_token_ids=encoder_inputs["prompt_token_ids"],
        mm_kwargs=encoder_inputs["mm_kwargs"],
        mm_hashes=encoder_inputs["mm_hashes"],
        mm_placeholders=encoder_inputs["mm_placeholders"],
    )

    if decoder_prompt is not None:
        inputs["prompt"] = decoder_prompt
    # The encoder's own prompt text is kept under a separate key so that it
    # does not collide with the decoder's "prompt".
    if "prompt" in encoder_inputs:
        inputs["encoder_prompt"] = encoder_inputs["prompt"]
    if "cache_salt" in encoder_inputs:
        inputs["cache_salt"] = encoder_inputs["cache_salt"]

    return inputs
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import Any, Generic, TypeVar from typing import Any, Generic, TypeVar
...@@ -26,7 +26,7 @@ class MediaWithBytes(Generic[_T]): ...@@ -26,7 +26,7 @@ class MediaWithBytes(Generic[_T]):
""" """
media: _T media: _T
original_bytes: bytes original_bytes: bytes = field(repr=False)
def __array__(self, *args, **kwargs) -> np.ndarray: def __array__(self, *args, **kwargs) -> np.ndarray:
"""Allow np.array(obj) to return np.array(obj.media).""" """Allow np.array(obj) to return np.array(obj.media)."""
......
...@@ -34,6 +34,8 @@ from ..inputs import ( ...@@ -34,6 +34,8 @@ from ..inputs import (
MultiModalKwargsOptionalItems, MultiModalKwargsOptionalItems,
MultiModalUUIDDict, MultiModalUUIDDict,
PlaceholderRange, PlaceholderRange,
mm_enc_dec_inputs,
mm_inputs,
) )
from ..parse import ( from ..parse import (
DictEmbeddingItems, DictEmbeddingItems,
...@@ -1803,8 +1805,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1803,8 +1805,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
for modality, placeholders in mm_placeholders.items() for modality, placeholders in mm_placeholders.items()
} }
return MultiModalInputs( return mm_inputs(
type="multimodal",
prompt_token_ids=prompt_ids, prompt_token_ids=prompt_ids,
mm_kwargs=mm_info.kwargs, mm_kwargs=mm_info.kwargs,
mm_hashes=mm_info.hashes, mm_hashes=mm_info.hashes,
...@@ -1848,12 +1849,10 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]): ...@@ -1848,12 +1849,10 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
else: else:
decoder_prompt_ids = decoder_prompt_raw decoder_prompt_ids = decoder_prompt_raw
mm_inputs = MultiModalEncDecInputs( return mm_enc_dec_inputs(
encoder_prompt_token_ids=encoder_inputs["prompt_token_ids"], encoder_inputs,
**encoder_inputs, decoder_prompt_ids,
) )
mm_inputs["prompt_token_ids"] = decoder_prompt_ids
return mm_inputs
def apply( def apply(
self, self,
......
...@@ -153,6 +153,27 @@ class BaseRenderer(ABC, Generic[_T]): ...@@ -153,6 +153,27 @@ class BaseRenderer(ABC, Generic[_T]):
return self.tokenizer.eos_token_id return self.tokenizer.eos_token_id
def get_dec_start_token_id(self) -> int:
    """
    Obtain the decoder start token id employed by an encoder/decoder model,
    raising an error if it is not available.

    The value is read from `decoder_start_token_id` on
    `self.model_config.hf_config`. When that attribute is missing or `None`,
    a warning is logged and `self.get_bos_token_id()` is used instead.

    Raises:
        RuntimeError: If neither the decoder start token id nor the BOS token
            id is available.
    """
    dec_start_token_id = getattr(
        self.model_config.hf_config, "decoder_start_token_id", None
    )
    # Compare against None explicitly: a token id of 0 is a valid value.
    if dec_start_token_id is not None:
        return dec_start_token_id

    logger.warning_once(
        "Falling back on <BOS> for decoder start token id "
        "because decoder start token id is not available."
    )
    dec_start_token_id = self.get_bos_token_id()

    if dec_start_token_id is None:
        raise RuntimeError("Cannot find decoder start token id or <BOS>")

    return dec_start_token_id
@cached_property @cached_property
def default_cmpl_tok_params(self) -> TokenizeParams: def default_cmpl_tok_params(self) -> TokenizeParams:
mm_processor = self.mm_processor mm_processor = self.mm_processor
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment