Unverified commit e1a5c2f0, authored by Aurick Qiao, committed by GitHub
Browse files

[Model] Whisper model implementation (#11280)


Co-authored-by: Aurick Qiao <aurick.qiao@snowflake.com>
parent fd3a62a1
...@@ -363,12 +363,14 @@ steps: ...@@ -363,12 +363,14 @@ steps:
- tests/models/decoder_only/audio_language - tests/models/decoder_only/audio_language
- tests/models/decoder_only/vision_language - tests/models/decoder_only/vision_language
- tests/models/embedding/vision_language - tests/models/embedding/vision_language
- tests/models/encoder_decoder/audio_language
- tests/models/encoder_decoder/vision_language - tests/models/encoder_decoder/vision_language
commands: commands:
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
- pytest -v -s models/decoder_only/audio_language -m 'core_model or quant_model' - pytest -v -s models/decoder_only/audio_language -m 'core_model or quant_model'
- pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'core_model or quant_model' - pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'core_model or quant_model'
- pytest -v -s models/embedding/vision_language -m core_model - pytest -v -s models/embedding/vision_language -m core_model
- pytest -v -s models/encoder_decoder/audio_language -m core_model
- pytest -v -s models/encoder_decoder/language -m core_model - pytest -v -s models/encoder_decoder/language -m core_model
- pytest -v -s models/encoder_decoder/vision_language -m core_model - pytest -v -s models/encoder_decoder/vision_language -m core_model
......
import time
from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset
# Create a Whisper encoder/decoder model instance.
llm = LLM(
    model="openai/whisper-large-v3",
    max_model_len=448,
    max_num_seqs=400,
    limit_mm_per_prompt={"audio": 1},
    kv_cache_dtype="fp8",
)

# Two request shapes are exercised: a single prompt dict carrying both the
# text and the audio, and an explicit encoder/decoder prompt pair. The pair
# is repeated 1024 times so that the throughput figure below is meaningful.
prompts = [
    {
        "prompt": "<|startoftranscript|>",
        "multi_modal_data": {
            "audio": AudioAsset("mary_had_lamb").audio_and_sample_rate,
        },
    },
    {  # Test explicit encoder/decoder prompt
        "encoder_prompt": {
            "prompt": "",
            "multi_modal_data": {
                "audio": AudioAsset("winning_call").audio_and_sample_rate,
            },
        },
        "decoder_prompt": "<|startoftranscript|>",
    }
] * 1024

# Create a sampling params object (temperature=0, i.e. greedy decoding).
sampling_params = SamplingParams(
    temperature=0,
    top_p=1.0,
    max_tokens=200,
)

# Time only the generation call. perf_counter() is monotonic, so the
# measured interval cannot be skewed by wall-clock adjustments.
start = time.perf_counter()

# Generate output tokens from the prompts. The output is a list of
# RequestOutput objects that contain the prompt, generated
# text, and other information.
outputs = llm.generate(prompts, sampling_params)

# Stop the clock before printing: writing thousands of lines to the console
# is not part of inference and would otherwise deflate the reported RPS.
duration = time.perf_counter() - start

# Print the outputs.
for output in outputs:
    prompt = output.prompt
    encoder_prompt = output.encoder_prompt
    generated_text = output.outputs[0].text
    print(f"Encoder prompt: {encoder_prompt!r}, "
          f"Decoder prompt: {prompt!r}, "
          f"Generated text: {generated_text!r}")

print("Duration:", duration)
print("RPS:", len(prompts) / duration)
"""Compare the outputs of HF and vLLM for Whisper models using greedy sampling.
Run `pytest tests/models/encoder_decoder/audio_language/test_whisper.py`.
"""
from typing import Optional
import pytest
from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset
from ....utils import fork_new_process_for_each_test, multi_gpu_test
# Requests sent to the model under test, one per audio asset. Order matters:
# run_test() pairs the generated outputs positionally with EXPECTED[model],
# so index 0 is "mary_had_lamb" and index 1 is "winning_call".
PROMPTS = [
# Form 1: a single prompt dict holding both the decoder text prompt and the
# audio. No explicit decoder prompt is given here.
{
"prompt":
"<|startoftranscript|><|en|><|transcribe|><|notimestamps|>",
"multi_modal_data": {
"audio": AudioAsset("mary_had_lamb").audio_and_sample_rate,
},
},
# Form 2: the encoder side carries only the audio (empty text prompt) and
# the decoder prompt is supplied separately.
{ # Test explicit encoder/decoder prompt
"encoder_prompt": {
"prompt": "",
"multi_modal_data": {
"audio": AudioAsset("winning_call").audio_and_sample_rate,
},
},
"decoder_prompt":
"<|startoftranscript|><|en|><|transcribe|><|notimestamps|>",
}
]
# Reference transcriptions keyed by model name. Each list is ordered to match
# PROMPTS (index 0 -> "mary_had_lamb", index 1 -> "winning_call"), because
# run_test() zips the generated outputs with EXPECTED[model] and compares the
# text for exact equality.
# NOTE(review): the strings look like verbatim model outputs, including
# apparent mis-hearings; presumably they must not be "corrected" by hand —
# confirm how they were recorded before editing.
# NOTE(review): the tests visible in this file only parametrize
# "openai/whisper-small" and "openai/whisper-large-v3-turbo"; the other
# entries are not exercised by them.
EXPECTED = {
"openai/whisper-tiny": [
" He has birth words I spoke in the original corner of that. And a"
" little piece of black coat poetry. Mary had a little sandwich,"
" sweet, with white and snow. And everyone had it very went the last"
" would sure to go.",
" >> And the old one, fit John the way to Edgar Martinez. >> One more"
" to line down the field line for our base camp. Here comes joy. Here"
" is June and the third base. They're going to wave him in. The throw"
" to the plate will be late. The Mariners are going to play for the"
" American League Championship. I don't believe it. It just continues"
" by all five."
],
"openai/whisper-small": [
" The first words I spoke in the original pornograph. A little piece"
" of practical poetry. Mary had a little lamb, its fleece was quite a"
" slow, and everywhere that Mary went the lamb was sure to go.",
" And the old one pitch on the way to Edgar Martinez one month. Here"
" comes joy. Here is Junior to third base. They're gonna wave him"
" in. The throw to the plate will be late. The Mariners are going to"
" play for the American League Championship. I don't believe it. It"
" just continues. My, oh my."
],
"openai/whisper-medium": [
" The first words I spoke in the original phonograph, a little piece"
" of practical poetry. Mary had a little lamb, its fleece was quite as"
" slow, and everywhere that Mary went the lamb was sure to go.",
" And the 0-1 pitch on the way to Edgar Martinez swung on the line"
" down the left field line for Obeyshev. Here comes Joy. Here is"
" Jorgen at third base. They're going to wave him in. The throw to the"
" plate will be late. The Mariners are going to play for the American"
" League Championship. I don't believe it. It just continues. My, oh"
" my."
],
"openai/whisper-large-v3": [
" The first words I spoke in the original phonograph, a little piece"
" of practical poetry. Mary had a little lamb, its feet were quite as"
" slow, and everywhere that Mary went, the lamb was sure to go.",
" And the 0-1 pitch on the way to Edgar Martinez. Swung on the line."
" Now the left field line for a base hit. Here comes Joy. Here is"
" Junior to third base. They're going to wave him in. The throw to the"
" plate will be late. The Mariners are going to play for the American"
" League Championship. I don't believe it. It just continues. My, oh,"
" my."
],
"openai/whisper-large-v3-turbo": [
" The first words I spoke in the original phonograph, a little piece"
" of practical poetry. Mary had a little lamb, its streets were quite"
" as slow, and everywhere that Mary went the lamb was sure to go.",
" And the 0-1 pitch on the way to Edgar Martinez. Swung on the line"
" down the left field line for a base hit. Here comes Joy. Here is"
" Junior to third base. They're going to wave him in. The throw to the"
" plate will be late. The Mariners are going to play for the American"
" League Championship. I don't believe it. It just continues. My, oh,"
" my."
]
}
def run_test(
    model: str,
    *,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
) -> None:
    """Generate transcriptions with vLLM and compare them to EXPECTED.

    The prompt set (and the matching expected texts) is repeated 10 times so
    that several requests are submitted in a single ``generate`` call.

    Args:
        model: Model name; must be a key of ``EXPECTED``.
        tensor_parallel_size: Forwarded to ``LLM``.
        distributed_executor_backend: Forwarded to ``LLM``.

    Raises:
        KeyError: If ``model`` has no entry in ``EXPECTED``.
        AssertionError: If the number of outputs differs from the number of
            expected texts, or any generated text differs from its reference.
    """
    prompt_list = PROMPTS * 10
    expected_list = EXPECTED[model] * 10

    llm = LLM(
        model=model,
        tensor_parallel_size=tensor_parallel_size,
        distributed_executor_backend=distributed_executor_backend,
    )

    sampling_params = SamplingParams(
        temperature=0,
        top_p=1.0,
        max_tokens=200,
    )

    outputs = llm.generate(prompt_list, sampling_params)

    # zip() stops at the shorter input, so without this check a missing (or
    # empty) output list would let the comparison loop pass vacuously.
    assert len(outputs) == len(expected_list), (
        f"expected {len(expected_list)} outputs, got {len(outputs)}")

    for output, expected in zip(outputs, expected_list):
        print(output.outputs[0].text)
        assert output.outputs[0].text == expected
@fork_new_process_for_each_test
@pytest.mark.core_model
@pytest.mark.parametrize(
    "model",
    [
        "openai/whisper-small",
        "openai/whisper-large-v3-turbo",
    ],
)
def test_models(model) -> None:
    """Check each parametrized checkpoint with tensor_parallel_size=1."""
    run_test(model=model, tensor_parallel_size=1)
@multi_gpu_test(num_gpus=2)
@pytest.mark.core_model
@pytest.mark.parametrize("model", ["openai/whisper-large-v3-turbo"])
@pytest.mark.parametrize("distributed_executor_backend", ["ray", "mp"])
def test_models_distributed(model, distributed_executor_backend) -> None:
    """Check the model with tensor_parallel_size=2 on each executor backend."""
    backend = distributed_executor_backend
    run_test(
        model=model,
        tensor_parallel_size=2,
        distributed_executor_backend=backend,
    )
...@@ -204,6 +204,7 @@ _MULTIMODAL_EXAMPLE_MODELS = { ...@@ -204,6 +204,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_3"), "UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_3"),
# [Encoder-decoder] # [Encoder-decoder]
"MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"), # noqa: E501 "MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"), # noqa: E501
"WhisperForConditionalGeneration": _HfExamplesInfo("openai/whisper-large-v3"), # noqa: E501
} }
_SPECULATIVE_DECODING_EXAMPLE_MODELS = { _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
......
...@@ -2312,6 +2312,8 @@ def _get_and_verify_max_len( ...@@ -2312,6 +2312,8 @@ def _get_and_verify_max_len(
"seq_length", "seq_length",
# Command-R # Command-R
"model_max_length", "model_max_length",
# Whisper
"max_target_positions",
# Others # Others
"max_sequence_length", "max_sequence_length",
"max_seq_length", "max_seq_length",
......
...@@ -184,10 +184,16 @@ class InputPreprocessor: ...@@ -184,10 +184,16 @@ class InputPreprocessor:
corresponding token IDs. corresponding token IDs.
""" """
tokenizer = self.get_tokenizer_group() tokenizer = self.get_tokenizer_group()
add_special_tokens = None
if self.model_config.hf_config.model_type == "whisper":
# For Whisper, special tokens should be provided by the user based
# on the task and language of their request. Also needed to avoid
# appending an EOS token to the prompt which disrupts generation.
add_special_tokens = False
return tokenizer.encode(request_id=request_id, return tokenizer.encode(request_id=request_id,
prompt=prompt, prompt=prompt,
lora_request=lora_request) lora_request=lora_request,
add_special_tokens=add_special_tokens)
async def _tokenize_prompt_async( async def _tokenize_prompt_async(
self, self,
...@@ -197,10 +203,17 @@ class InputPreprocessor: ...@@ -197,10 +203,17 @@ class InputPreprocessor:
) -> List[int]: ) -> List[int]:
"""Async version of :meth:`_tokenize_prompt`.""" """Async version of :meth:`_tokenize_prompt`."""
tokenizer = self.get_tokenizer_group() tokenizer = self.get_tokenizer_group()
add_special_tokens = None
return await tokenizer.encode_async(request_id=request_id, if self.model_config.hf_config.model_type == "whisper":
prompt=prompt, # For Whisper, special tokens should be provided by the user based
lora_request=lora_request) # on the task and language of their request. Also needed to avoid
# appending an EOS token to the prompt which disrupts generation.
add_special_tokens = False
return await tokenizer.encode_async(
request_id=request_id,
prompt=prompt,
lora_request=lora_request,
add_special_tokens=add_special_tokens)
def _can_process_multimodal(self) -> bool: def _can_process_multimodal(self) -> bool:
model_config = self.model_config model_config = self.model_config
...@@ -439,8 +452,15 @@ class InputPreprocessor: ...@@ -439,8 +452,15 @@ class InputPreprocessor:
assert_never(encoder_inputs) # type: ignore[arg-type] assert_never(encoder_inputs) # type: ignore[arg-type]
if decoder_inputs is None: if decoder_inputs is None:
dec_token_ids = self._prepare_decoder_input_ids_for_generation( if self.model_config.hf_config.model_type == "whisper":
None) # For Whisper models, the text prompt should go to the decoder.
# If no explicit encoder/decoder inputs, then copy the prompt
# from the encoder to the decoder. The encoder tokens are later
# overridden by the audio features.
dec_token_ids = encoder_inputs["prompt_token_ids"].copy()
else:
dec_token_ids = self._prepare_decoder_input_ids_for_generation(
None)
decoder_inputs = token_inputs(dec_token_ids) decoder_inputs = token_inputs(dec_token_ids)
elif (decoder_inputs["type"] == "token" elif (decoder_inputs["type"] == "token"
or decoder_inputs["type"] == "multimodal"): or decoder_inputs["type"] == "multimodal"):
......
...@@ -170,6 +170,7 @@ _MULTIMODAL_MODELS = { ...@@ -170,6 +170,7 @@ _MULTIMODAL_MODELS = {
"UltravoxModel": ("ultravox", "UltravoxModel"), "UltravoxModel": ("ultravox", "UltravoxModel"),
# [Encoder-decoder] # [Encoder-decoder]
"MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), # noqa: E501 "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), # noqa: E501
"WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"), # noqa: E501
} }
_SPECULATIVE_DECODING_MODELS = { _SPECULATIVE_DECODING_MODELS = {
......
This diff is collapsed.
...@@ -16,7 +16,7 @@ from transformers import BatchFeature, ProcessorMixin ...@@ -16,7 +16,7 @@ from transformers import BatchFeature, ProcessorMixin
from vllm.inputs import DummyData, InputProcessingContext from vllm.inputs import DummyData, InputProcessingContext
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer, encode_tokens
from vllm.utils import LRUCache, flatten_2d_lists, full_groupby from vllm.utils import LRUCache, flatten_2d_lists, full_groupby
from .inputs import (MultiModalDataDict, MultiModalFieldConfig, from .inputs import (MultiModalDataDict, MultiModalFieldConfig,
...@@ -57,24 +57,6 @@ class PromptReplacement: ...@@ -57,24 +57,6 @@ class PromptReplacement:
) )
def _encode(
tokenizer: AnyTokenizer,
text: str,
*,
add_special_tokens: bool = False,
) -> list[int]:
"""
Backend-agnostic equivalent of HF's
:code:`tokenizer.encode(text, add_special_tokens=...)`.
"""
if isinstance(tokenizer, MistralTokenizer):
return tokenizer.tokenizer.encode(text,
bos=add_special_tokens,
eos=add_special_tokens)
return tokenizer.encode(text, add_special_tokens=add_special_tokens)
@lru_cache(maxsize=2048) @lru_cache(maxsize=2048)
def _cached_encode( def _cached_encode(
tokenizer: AnyTokenizer, tokenizer: AnyTokenizer,
...@@ -82,7 +64,9 @@ def _cached_encode( ...@@ -82,7 +64,9 @@ def _cached_encode(
*, *,
add_special_tokens: bool = False, add_special_tokens: bool = False,
) -> list[int]: ) -> list[int]:
return _encode(tokenizer, text, add_special_tokens=add_special_tokens) return encode_tokens(tokenizer,
text,
add_special_tokens=add_special_tokens)
def _decode( def _decode(
...@@ -983,7 +967,9 @@ class BaseMultiModalProcessor(ABC): ...@@ -983,7 +967,9 @@ class BaseMultiModalProcessor(ABC):
mm_item_counts, mm_item_counts,
) )
token_ids = _encode(tokenizer, text) token_ids = encode_tokens(tokenizer,
text,
add_special_tokens=False)
matched_repls = [match.prompt_repl for match in text_matches] matched_repls = [match.prompt_repl for match in text_matches]
placeholders = self._find_placeholders(matched_repls, token_ids, placeholders = self._find_placeholders(matched_repls, token_ids,
......
...@@ -710,15 +710,27 @@ class SequenceGroup: ...@@ -710,15 +710,27 @@ class SequenceGroup:
@property @property
def multi_modal_data(self) -> MultiModalDataDict: def multi_modal_data(self) -> MultiModalDataDict:
return self.first_seq.multi_modal_data if self.first_seq.multi_modal_data:
return self.first_seq.multi_modal_data
elif self.encoder_seq is not None:
return self.encoder_seq.multi_modal_data
return {}
@property @property
def multi_modal_placeholders(self) -> MultiModalPlaceholderDict: def multi_modal_placeholders(self) -> MultiModalPlaceholderDict:
return self.first_seq.multi_modal_placeholders if self.first_seq.multi_modal_data:
return self.first_seq.multi_modal_placeholders
elif self.encoder_seq is not None:
return self.encoder_seq.multi_modal_placeholders
return {}
@property @property
def mm_processor_kwargs(self) -> Dict[str, Any]: def mm_processor_kwargs(self) -> Dict[str, Any]:
return self.first_seq.mm_processor_kwargs if self.first_seq.multi_modal_data:
return self.first_seq.mm_processor_kwargs
elif self.encoder_seq is not None:
return self.encoder_seq.mm_processor_kwargs
return {}
@property @property
def lora_int_id(self) -> int: def lora_int_id(self) -> int:
......
...@@ -21,6 +21,25 @@ AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast, ...@@ -21,6 +21,25 @@ AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast,
MistralTokenizer] MistralTokenizer]
def encode_tokens(
    tokenizer: AnyTokenizer,
    text: str,
    *,
    add_special_tokens: Optional[bool] = None,
) -> list[int]:
    """
    Backend-agnostic equivalent of HF's
    :code:`tokenizer.encode(text, add_special_tokens=...)`.

    Args:
        tokenizer: The tokenizer to encode with.
        text: The text to encode.
        add_special_tokens: ``True``/``False`` to force special tokens on or
            off; ``None`` to keep the tokenizer's own default behaviour.

    Returns:
        The token IDs produced by the tokenizer.
    """
    if add_special_tokens is None:
        # "No preference": defer to the tokenizer's default for every
        # backend. Forwarding None as bos/eos to the Mistral backend would be
        # treated as falsy and silently drop the special tokens.
        # NOTE(review): relies on MistralTokenizer.encode(text) accepting a
        # single argument, as the previous TokenizerGroup.encode call did —
        # confirm.
        return tokenizer.encode(text)

    if isinstance(tokenizer, MistralTokenizer):
        # The underlying Mistral tokenizer takes separate bos/eos flags
        # instead of a single add_special_tokens switch.
        return tokenizer.tokenizer.encode(text,
                                          bos=add_special_tokens,
                                          eos=add_special_tokens)

    return tokenizer.encode(text, add_special_tokens=add_special_tokens)
def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer: def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
"""Get tokenizer with cached properties. """Get tokenizer with cached properties.
......
...@@ -32,7 +32,8 @@ class BaseTokenizerGroup(ABC): ...@@ -32,7 +32,8 @@ class BaseTokenizerGroup(ABC):
def encode(self, def encode(self,
prompt: str, prompt: str,
request_id: Optional[str] = None, request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None) -> List[int]: lora_request: Optional[LoRARequest] = None,
add_special_tokens: Optional[bool] = None) -> List[int]:
"""Encode a prompt using the tokenizer group.""" """Encode a prompt using the tokenizer group."""
pass pass
...@@ -41,7 +42,8 @@ class BaseTokenizerGroup(ABC): ...@@ -41,7 +42,8 @@ class BaseTokenizerGroup(ABC):
self, self,
prompt: str, prompt: str,
request_id: Optional[str] = None, request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None) -> List[int]: lora_request: Optional[LoRARequest] = None,
add_special_tokens: Optional[bool] = None) -> List[int]:
"""Encode a prompt using the tokenizer group.""" """Encode a prompt using the tokenizer group."""
pass pass
......
...@@ -112,7 +112,8 @@ class RayTokenizerGroupPool(BaseTokenizerGroup): ...@@ -112,7 +112,8 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
def encode(self, def encode(self,
prompt: str, prompt: str,
request_id: Optional[str] = None, request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None) -> List[int]: lora_request: Optional[LoRARequest] = None,
add_special_tokens: Optional[bool] = None) -> List[int]:
"""Encode a prompt using the tokenizer group. """Encode a prompt using the tokenizer group.
We pick an idle actor and use it to encode the prompt. We pick an idle actor and use it to encode the prompt.
...@@ -132,7 +133,8 @@ class RayTokenizerGroupPool(BaseTokenizerGroup): ...@@ -132,7 +133,8 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
ret = ray.get( ret = ray.get(
actor.encode.remote(request_id=request_id, actor.encode.remote(request_id=request_id,
prompt=prompt, prompt=prompt,
lora_request=lora_request)) lora_request=lora_request,
add_special_tokens=add_special_tokens))
except ActorDiedError as e: except ActorDiedError as e:
# If the actor is dead, we first try to reinitialize it. # If the actor is dead, we first try to reinitialize it.
logger.warning("%s died with ActorDiedError, reinitializing.", logger.warning("%s died with ActorDiedError, reinitializing.",
...@@ -143,7 +145,8 @@ class RayTokenizerGroupPool(BaseTokenizerGroup): ...@@ -143,7 +145,8 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
ret = ray.get( ret = ray.get(
actor.encode.remote(request_id=request_id, actor.encode.remote(request_id=request_id,
prompt=prompt, prompt=prompt,
lora_request=lora_request)) lora_request=lora_request,
add_special_tokens=add_special_tokens))
except ActorDiedError as e: except ActorDiedError as e:
logger.error( logger.error(
"%s died for second time in a row, marking " "%s died for second time in a row, marking "
...@@ -160,7 +163,8 @@ class RayTokenizerGroupPool(BaseTokenizerGroup): ...@@ -160,7 +163,8 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
self, self,
prompt: str, prompt: str,
request_id: Optional[str] = None, request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None) -> List[int]: lora_request: Optional[LoRARequest] = None,
add_special_tokens: Optional[bool] = None) -> List[int]:
"""Encode a prompt using the tokenizer group. """Encode a prompt using the tokenizer group.
We pick an idle actor and use it to encode the prompt. We pick an idle actor and use it to encode the prompt.
...@@ -177,9 +181,11 @@ class RayTokenizerGroupPool(BaseTokenizerGroup): ...@@ -177,9 +181,11 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
actor_is_alive = True actor_is_alive = True
original_actor = actor original_actor = actor
try: try:
ret = await actor.encode.remote(request_id=request_id, ret = await actor.encode.remote(
prompt=prompt, request_id=request_id,
lora_request=lora_request) prompt=prompt,
lora_request=lora_request,
add_special_tokens=add_special_tokens)
except ActorDiedError as e: except ActorDiedError as e:
# If the actor is dead, we first try to reinitialize it. # If the actor is dead, we first try to reinitialize it.
logger.warning("%s died with ActorDiedError, reinitializing.", logger.warning("%s died with ActorDiedError, reinitializing.",
...@@ -187,9 +193,11 @@ class RayTokenizerGroupPool(BaseTokenizerGroup): ...@@ -187,9 +193,11 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
exc_info=e) exc_info=e)
actor = self._init_actor() actor = self._init_actor()
try: try:
ret = await actor.encode.remote(request_id=request_id, ret = await actor.encode.remote(
prompt=prompt, request_id=request_id,
lora_request=lora_request) prompt=prompt,
lora_request=lora_request,
add_special_tokens=add_special_tokens)
except ActorDiedError as e: except ActorDiedError as e:
logger.error( logger.error(
"%s died for second time in a row, marking " "%s died for second time in a row, marking "
......
...@@ -2,7 +2,7 @@ from typing import List, Optional ...@@ -2,7 +2,7 @@ from typing import List, Optional
from vllm.config import TokenizerPoolConfig from vllm.config import TokenizerPoolConfig
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizer import (AnyTokenizer, from vllm.transformers_utils.tokenizer import (AnyTokenizer, encode_tokens,
get_lora_tokenizer, get_lora_tokenizer,
get_lora_tokenizer_async, get_lora_tokenizer_async,
get_tokenizer) get_tokenizer)
...@@ -55,9 +55,12 @@ class TokenizerGroup(BaseTokenizerGroup): ...@@ -55,9 +55,12 @@ class TokenizerGroup(BaseTokenizerGroup):
def encode(self, def encode(self,
prompt: str, prompt: str,
request_id: Optional[str] = None, request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None) -> List[int]: lora_request: Optional[LoRARequest] = None,
add_special_tokens: Optional[bool] = None) -> List[int]:
tokenizer = self.get_lora_tokenizer(lora_request) tokenizer = self.get_lora_tokenizer(lora_request)
ret = tokenizer.encode(prompt) ret = encode_tokens(tokenizer,
prompt,
add_special_tokens=add_special_tokens)
self._raise_if_input_too_long(ret, lora_request) self._raise_if_input_too_long(ret, lora_request)
return ret return ret
...@@ -65,9 +68,12 @@ class TokenizerGroup(BaseTokenizerGroup): ...@@ -65,9 +68,12 @@ class TokenizerGroup(BaseTokenizerGroup):
self, self,
prompt: str, prompt: str,
request_id: Optional[str] = None, request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None) -> List[int]: lora_request: Optional[LoRARequest] = None,
add_special_tokens: Optional[bool] = None) -> List[int]:
tokenizer = await self.get_lora_tokenizer_async(lora_request) tokenizer = await self.get_lora_tokenizer_async(lora_request)
ret = tokenizer.encode(prompt) ret = encode_tokens(tokenizer,
prompt,
add_special_tokens=add_special_tokens)
self._raise_if_input_too_long(ret, lora_request) self._raise_if_input_too_long(ret, lora_request)
return ret return ret
......
...@@ -287,12 +287,11 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]): ...@@ -287,12 +287,11 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
seq_len, seq_len,
self.mm_registry, self.mm_registry,
is_encoder_data=False) is_encoder_data=False)
encoder_dummy_data \ encoder_dummy_data = self.input_registry \
= self.input_registry.dummy_data_for_profiling( .dummy_data_for_profiling(self.model_config,
self.model_config, seq_len,
seq_len, self.mm_registry,
self.mm_registry, is_encoder_data=True)
is_encoder_data=True)
# Having more tokens is over-conservative but otherwise fine # Having more tokens is over-conservative but otherwise fine
assert len( assert len(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment