Unverified commit 6c47f6bf, authored by Zhuohan Li, committed by GitHub
Browse files

[Core] Remove tokenizer group in vLLM (#24078)


Signed-off-by: Zhuohan Li <zhuohan123@gmail.com>
parent c15309a7
...@@ -43,7 +43,7 @@ def _ref_convert_id_to_token( ...@@ -43,7 +43,7 @@ def _ref_convert_id_to_token(
[RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]) [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
def test_incremental_detokenization(request_output_kind: RequestOutputKind, def test_incremental_detokenization(request_output_kind: RequestOutputKind,
dummy_test_vectors): dummy_test_vectors):
output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group, output_processor = OutputProcessor(dummy_test_vectors.tokenizer,
log_stats=False) log_stats=False)
engine_core = MockEngineCore( engine_core = MockEngineCore(
tokens_list=dummy_test_vectors.generation_tokens) tokens_list=dummy_test_vectors.generation_tokens)
...@@ -382,7 +382,7 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind, ...@@ -382,7 +382,7 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind,
num_sample_logprobs: Optional[int], num_sample_logprobs: Optional[int],
num_prompt_logprobs: Optional[int], num_prompt_logprobs: Optional[int],
dummy_test_vectors): dummy_test_vectors):
output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group, output_processor = OutputProcessor(dummy_test_vectors.tokenizer,
log_stats=False) log_stats=False)
engine_core = MockEngineCore( engine_core = MockEngineCore(
tokens_list=dummy_test_vectors.generation_tokens, tokens_list=dummy_test_vectors.generation_tokens,
...@@ -535,7 +535,7 @@ def test_stop_token(include_stop_str_in_output: bool, ...@@ -535,7 +535,7 @@ def test_stop_token(include_stop_str_in_output: bool,
) # '<|end_of_text|>' ) # '<|end_of_text|>'
stop_token_ids = [128009] if not is_eos_test else None # '<|eot_id|>' stop_token_ids = [128009] if not is_eos_test else None # '<|eot_id|>'
output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group, output_processor = OutputProcessor(dummy_test_vectors.tokenizer,
log_stats=False) log_stats=False)
# Dummy engine core outputs, with control tokens suffixed to test stops # Dummy engine core outputs, with control tokens suffixed to test stops
suffix_token = ([eos_token_id] if is_eos_test else stop_token_ids) suffix_token = ([eos_token_id] if is_eos_test else stop_token_ids)
...@@ -642,7 +642,7 @@ def test_stop_token(include_stop_str_in_output: bool, ...@@ -642,7 +642,7 @@ def test_stop_token(include_stop_str_in_output: bool,
[None, NUM_SAMPLE_LOGPROBS_UNDER_TEST]) [None, NUM_SAMPLE_LOGPROBS_UNDER_TEST])
def test_stop_string(include_stop_str_in_output: bool, def test_stop_string(include_stop_str_in_output: bool,
num_sample_logprobs: Optional[int], dummy_test_vectors): num_sample_logprobs: Optional[int], dummy_test_vectors):
output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group, output_processor = OutputProcessor(dummy_test_vectors.tokenizer,
log_stats=False) log_stats=False)
engine_core = MockEngineCore( engine_core = MockEngineCore(
tokens_list=dummy_test_vectors.generation_tokens, tokens_list=dummy_test_vectors.generation_tokens,
...@@ -763,7 +763,7 @@ def test_stop_string(include_stop_str_in_output: bool, ...@@ -763,7 +763,7 @@ def test_stop_string(include_stop_str_in_output: bool,
def test_iteration_stats(dummy_test_vectors): def test_iteration_stats(dummy_test_vectors):
output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group, output_processor = OutputProcessor(dummy_test_vectors.tokenizer,
log_stats=True) log_stats=True)
engine_core = MockEngineCore(dummy_test_vectors.generation_tokens) engine_core = MockEngineCore(dummy_test_vectors.generation_tokens)
engine_core_timestamp = time.monotonic() engine_core_timestamp = time.monotonic()
......
...@@ -9,7 +9,6 @@ import torch ...@@ -9,7 +9,6 @@ import torch
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.v1.engine import EngineCoreOutput, FinishReason from vllm.v1.engine import EngineCoreOutput, FinishReason
from vllm.v1.outputs import LogprobsLists, LogprobsTensors from vllm.v1.outputs import LogprobsLists, LogprobsTensors
...@@ -39,7 +38,7 @@ def _create_random_top_logprob_test_vector( ...@@ -39,7 +38,7 @@ def _create_random_top_logprob_test_vector(
upper: float, upper: float,
) -> torch.Tensor: ) -> torch.Tensor:
"""Create a random vector of top logprob float values. """Create a random vector of top logprob float values.
Use to create fake sample logprobs for testing. Use to create fake sample logprobs for testing.
Note that a real production scenario would require Note that a real production scenario would require
...@@ -63,7 +62,7 @@ def _create_random_top_logprob_test_matrix( ...@@ -63,7 +62,7 @@ def _create_random_top_logprob_test_matrix(
upper: float, upper: float,
) -> torch.Tensor: ) -> torch.Tensor:
"""Create a random matrix of top logprob float values. """Create a random matrix of top logprob float values.
Use to create fake prompt logprobs for testing. Use to create fake prompt logprobs for testing.
Note that a real production scenario would require Note that a real production scenario would require
...@@ -296,7 +295,6 @@ def generate_dummy_prompt_logprobs_tensors( ...@@ -296,7 +295,6 @@ def generate_dummy_prompt_logprobs_tensors(
class DummyOutputProcessorTestVectors: class DummyOutputProcessorTestVectors:
"""Dummy test vectors for output processor tests""" """Dummy test vectors for output processor tests"""
tokenizer: GeneralTokenizerType tokenizer: GeneralTokenizerType
tokenizer_group: TokenizerGroup
vllm_config: EngineArgs vllm_config: EngineArgs
full_tokens: list[list[int]] # Prompt + generated tokens full_tokens: list[list[int]] # Prompt + generated tokens
prompt_tokens: list[list[int]] prompt_tokens: list[list[int]]
......
...@@ -582,7 +582,7 @@ def test_structured_output_with_reasoning_matrices( ...@@ -582,7 +582,7 @@ def test_structured_output_with_reasoning_matrices(
reasoning_parser=reasoning_parser, reasoning_parser=reasoning_parser,
speculative_config=speculative_config, speculative_config=speculative_config,
) )
tokenizer = llm.get_tokenizer(None) tokenizer = llm.get_tokenizer()
reasoner = ReasoningParserManager.get_reasoning_parser(reasoning_parser)( reasoner = ReasoningParserManager.get_reasoning_parser(reasoning_parser)(
tokenizer=tokenizer) tokenizer=tokenizer)
......
...@@ -37,7 +37,7 @@ from vllm.lora.request import LoRARequest ...@@ -37,7 +37,7 @@ from vllm.lora.request import LoRARequest
from vllm.lora.utils import get_adapter_absolute_path from vllm.lora.utils import get_adapter_absolute_path
from vllm.multimodal import MultiModalDataDict from vllm.multimodal import MultiModalDataDict
from vllm.multimodal.image import convert_image_mode from vllm.multimodal.image import convert_image_mode
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import PlaceholderModule from vllm.utils import PlaceholderModule
try: try:
...@@ -100,8 +100,8 @@ class BenchmarkDataset(ABC): ...@@ -100,8 +100,8 @@ class BenchmarkDataset(ABC):
) -> None: ) -> None:
""" """
Initialize the BenchmarkDataset with an optional dataset path and random Initialize the BenchmarkDataset with an optional dataset path and random
seed. seed.
Args: Args:
dataset_path (Optional[str]): Path to the dataset. If None, it dataset_path (Optional[str]): Path to the dataset. If None, it
indicates that a default or random dataset might be used. indicates that a default or random dataset might be used.
...@@ -133,10 +133,10 @@ class BenchmarkDataset(ABC): ...@@ -133,10 +133,10 @@ class BenchmarkDataset(ABC):
elif isinstance(mm_content, dict): elif isinstance(mm_content, dict):
content.append(mm_content) content.append(mm_content)
else: else:
raise TypeError( raise TypeError(
"Could not process multimodal content of type: " + "Could not process multimodal content of type: " +
f"{type(mm_content)}" f"{type(mm_content)}"
) )
return [{"role": "user", "content": content}] return [{"role": "user", "content": content}]
def load_data(self) -> None: def load_data(self) -> None:
...@@ -155,34 +155,26 @@ class BenchmarkDataset(ABC): ...@@ -155,34 +155,26 @@ class BenchmarkDataset(ABC):
def get_random_lora_request( def get_random_lora_request(
self, self,
tokenizer: PreTrainedTokenizerBase,
max_loras: Optional[int] = None, max_loras: Optional[int] = None,
lora_path: Optional[str] = None, lora_path: Optional[str] = None,
) -> tuple[Optional[LoRARequest], AnyTokenizer]: ) -> Optional[LoRARequest]:
""" """
Optionally select a random LoRA request and return its associated Optionally select a random LoRA request.
tokenizer.
This method is used when LoRA parameters are provided. It randomly This method is used when LoRA parameters are provided. It randomly
selects a LoRA based on max_loras and retrieves a cached tokenizer for selects a LoRA based on max_loras.
that LoRA if available. Otherwise, it returns the base tokenizer.
Args: Args:
tokenizer (PreTrainedTokenizerBase): The base tokenizer to use if no
LoRA is selected.
max_loras (Optional[int]): The maximum number of LoRAs available. max_loras (Optional[int]): The maximum number of LoRAs available.
If `None`, LoRA is not used. If `None`, LoRA is not used.
lora_path (Optional[str]): Path to the LoRA parameters on disk. lora_path (Optional[str]): Path to the LoRA parameters on disk.
If `None`, LoRA is not used. If `None`, LoRA is not used.
Returns: Returns:
A tuple with the following elements: A new [LoRARequest][] (or `None` if not applicable).
- A new [LoRARequest][] (or `None` if not applicable).
- The tokenizer associated with the LoRA request
(or the base tokenizer).
""" """
if max_loras is None or lora_path is None: if max_loras is None or lora_path is None:
return None, tokenizer return None
# Generate a random LoRA ID in the range [1, max_loras]. # Generate a random LoRA ID in the range [1, max_loras].
lora_id = random.randint(1, max_loras) lora_id = random.randint(1, max_loras)
...@@ -191,11 +183,7 @@ class BenchmarkDataset(ABC): ...@@ -191,11 +183,7 @@ class BenchmarkDataset(ABC):
lora_int_id=lora_id, lora_int_id=lora_id,
lora_path=lora_path_on_disk(lora_path), lora_path=lora_path_on_disk(lora_path),
) )
if lora_id not in lora_tokenizer_cache: return lora_request
lora_tokenizer_cache[lora_id] = get_lora_tokenizer(lora_request)
# Return lora_request and the cached tokenizer if available; otherwise,
# return the base tokenizer
return lora_request, lora_tokenizer_cache[lora_id] or tokenizer
@abstractmethod @abstractmethod
def sample(self, tokenizer: PreTrainedTokenizerBase, def sample(self, tokenizer: PreTrainedTokenizerBase,
...@@ -213,7 +201,7 @@ class BenchmarkDataset(ABC): ...@@ -213,7 +201,7 @@ class BenchmarkDataset(ABC):
for processing the dataset's text. for processing the dataset's text.
num_requests (int): The number of sample requests to generate. num_requests (int): The number of sample requests to generate.
request_id_prefix (str): The prefix of request_id. request_id_prefix (str): The prefix of request_id.
Returns: Returns:
list[SampleRequest]: A list of sample requests generated from the list[SampleRequest]: A list of sample requests generated from the
...@@ -527,7 +515,7 @@ class RandomDataset(BenchmarkDataset): ...@@ -527,7 +515,7 @@ class RandomDataset(BenchmarkDataset):
size=num_requests) size=num_requests)
output_lens = self._rng.integers(output_low, output_high + 1, output_lens = self._rng.integers(output_low, output_high + 1,
size=num_requests) size=num_requests)
offsets = self._rng.integers(0, tokenizer.vocab_size, offsets = self._rng.integers(0, tokenizer.vocab_size,
size=num_requests) size=num_requests)
return input_lens, output_lens, offsets return input_lens, output_lens, offsets
...@@ -555,7 +543,7 @@ class RandomDataset(BenchmarkDataset): ...@@ -555,7 +543,7 @@ class RandomDataset(BenchmarkDataset):
the encoded sequence is truncated before being decoded again. the encoded sequence is truncated before being decoded again.
""" """
# Build the inner sequence by sampling sequentially from the vocab # Build the inner sequence by sampling sequentially from the vocab
inner_seq = ((offset + index + np.arange(input_len)) inner_seq = ((offset + index + np.arange(input_len))
% vocab_size).tolist() % vocab_size).tolist()
token_sequence = prefix_token_ids + inner_seq token_sequence = prefix_token_ids + inner_seq
...@@ -590,9 +578,9 @@ class RandomMultiModalDataset(RandomDataset): ...@@ -590,9 +578,9 @@ class RandomMultiModalDataset(RandomDataset):
`num_mm_items_range_ratio` in [0, 1]. r=0 keeps it fixed; r=1 allows 0. `num_mm_items_range_ratio` in [0, 1]. r=0 keeps it fixed; r=1 allows 0.
The maximum is further clamped to the sum of per-modality limits. The maximum is further clamped to the sum of per-modality limits.
2) Each item’s modality and shape is sampled from `bucket_config`, a dict 2) Each item’s modality and shape is sampled from `bucket_config`, a dict
mapping (height, width, num_frames) → probability. We treat mapping (height, width, num_frames) → probability. We treat
`num_frames`=1 as image and `num_frames` > 1 as video. `num_frames`=1 as image and `num_frames` > 1 as video.
Entries with zero probability are removed and the rest are renormalized Entries with zero probability are removed and the rest are renormalized
to sum to 1. to sum to 1.
3) Per-modality hard caps are enforced via `limit_mm_per_prompt`. 3) Per-modality hard caps are enforced via `limit_mm_per_prompt`.
When a modality reaches its cap, all of its buckets are excluded and the When a modality reaches its cap, all of its buckets are excluded and the
...@@ -600,8 +588,8 @@ class RandomMultiModalDataset(RandomDataset): ...@@ -600,8 +588,8 @@ class RandomMultiModalDataset(RandomDataset):
Example bucket configuration: Example bucket configuration:
{(256, 256, 1): 0.5, (720, 1280, 1): 0.4, (720, 1280, 16): 0.1} {(256, 256, 1): 0.5, (720, 1280, 1): 0.4, (720, 1280, 16): 0.1}
- Two image buckets (`num_frames`=1) and one video bucket - Two image buckets (`num_frames`=1) and one video bucket
(`num_frames`=16). (`num_frames`=16).
OBS.: Only image sampling is supported for now. OBS.: Only image sampling is supported for now.
""" """
...@@ -624,9 +612,9 @@ class RandomMultiModalDataset(RandomDataset): ...@@ -624,9 +612,9 @@ class RandomMultiModalDataset(RandomDataset):
def generate_synthetic_image(self, width: int, height: int) -> Image.Image: def generate_synthetic_image(self, width: int, height: int) -> Image.Image:
"""Generate synthetic PIL image with random RGB values. """Generate synthetic PIL image with random RGB values.
NOTE: iid pixel sampling results in worst-case compression NOTE: iid pixel sampling results in worst-case compression
(good for stressing I/O), but very unlike real photos. (good for stressing I/O), but very unlike real photos.
We could consider a “low-freq” mode (e.g., noise blur) We could consider a “low-freq” mode (e.g., noise blur)
to emulate network realism instead of max stress. to emulate network realism instead of max stress.
""" """
...@@ -638,11 +626,11 @@ class RandomMultiModalDataset(RandomDataset): ...@@ -638,11 +626,11 @@ class RandomMultiModalDataset(RandomDataset):
) )
return Image.fromarray(random_pixels) return Image.fromarray(random_pixels)
def generate_synthetic_video(self, width: int, def generate_synthetic_video(self, width: int,
height: int, height: int,
num_frames: int) -> Any: num_frames: int) -> Any:
"""Generate synthetic video with random values. """Generate synthetic video with random values.
TODO: Finish this method. TODO: Finish this method.
""" """
raise NotImplementedError("Video sampling is WIP.") raise NotImplementedError("Video sampling is WIP.")
...@@ -656,7 +644,7 @@ class RandomMultiModalDataset(RandomDataset): ...@@ -656,7 +644,7 @@ class RandomMultiModalDataset(RandomDataset):
else: else:
raise ValueError(f"Invalid multimodal item configuration: {config}") raise ValueError(f"Invalid multimodal item configuration: {config}")
def normalize_bucket_config(self, bucket_config: dict[tuple[int, int, int], def normalize_bucket_config(self, bucket_config: dict[tuple[int, int, int],
float]) -> dict[tuple[int, int, int], float]: float]) -> dict[tuple[int, int, int], float]:
""" """
Remove zero probability entries Remove zero probability entries
...@@ -676,24 +664,24 @@ class RandomMultiModalDataset(RandomDataset): ...@@ -676,24 +664,24 @@ class RandomMultiModalDataset(RandomDataset):
return {k: v / total for k, v in bucket_config.items()} return {k: v / total for k, v in bucket_config.items()}
def generate_mm_item(self, def generate_mm_item(self,
mm_item_config: tuple[int, int, int], mm_item_config: tuple[int, int, int],
) -> Mapping[str, Any]: ) -> Mapping[str, Any]:
""" """
Create synthetic images and videos and Create synthetic images and videos and
apply process_image/process_video respectively. apply process_image/process_video respectively.
This follows the OpenAI API chat completions This follows the OpenAI API chat completions
https://github.com/openai/openai-python https://github.com/openai/openai-python
""" """
if self.map_config_to_modality(mm_item_config) == "image": if self.map_config_to_modality(mm_item_config) == "image":
return process_image(self.generate_synthetic_image( return process_image(self.generate_synthetic_image(
mm_item_config[1], mm_item_config[1],
mm_item_config[0])) mm_item_config[0]))
elif self.map_config_to_modality(mm_item_config) == "video": elif self.map_config_to_modality(mm_item_config) == "video":
return process_video(self.generate_synthetic_video( return process_video(self.generate_synthetic_video(
mm_item_config[1], mm_item_config[1],
mm_item_config[0], mm_item_config[0],
mm_item_config[2])) mm_item_config[2]))
else: else:
raise ValueError(f"Invalid multimodal item configuration: " raise ValueError(f"Invalid multimodal item configuration: "
...@@ -723,17 +711,17 @@ class RandomMultiModalDataset(RandomDataset): ...@@ -723,17 +711,17 @@ class RandomMultiModalDataset(RandomDataset):
f"limit_mm_per_prompt: " f"limit_mm_per_prompt: "
f"{limit_mm_per_prompt.keys()}") f"{limit_mm_per_prompt.keys()}")
# Remove zero probability entries # Remove zero probability entries
# and normalize bucket config to sum to 1 # and normalize bucket config to sum to 1
bucket_config = self.normalize_bucket_config(bucket_config) bucket_config = self.normalize_bucket_config(bucket_config)
logger.info( logger.info(
"Normalized bucket config: %s", bucket_config, "Normalized bucket config: %s", bucket_config,
) )
# Only consider limit per prompt for modalities in bucket config # Only consider limit per prompt for modalities in bucket config
allowed_modalities = {self.map_config_to_modality(cfg) allowed_modalities = {self.map_config_to_modality(cfg)
for cfg in bucket_config} for cfg in bucket_config}
limit_mm_per_prompt = { limit_mm_per_prompt = {
k: v for k, v in limit_mm_per_prompt.items() k: v for k, v in limit_mm_per_prompt.items()
if k in allowed_modalities} if k in allowed_modalities}
if not limit_mm_per_prompt: if not limit_mm_per_prompt:
raise ValueError("No valid limits for modalities present in " raise ValueError("No valid limits for modalities present in "
...@@ -746,19 +734,19 @@ class RandomMultiModalDataset(RandomDataset): ...@@ -746,19 +734,19 @@ class RandomMultiModalDataset(RandomDataset):
# Get max and min num mm items and ensure # Get max and min num mm items and ensure
# it is at most the sum of limit_mm_per_prompt for all modalities # it is at most the sum of limit_mm_per_prompt for all modalities
max_num_mm_items = min( max_num_mm_items = min(
sum(limit_mm_per_prompt.values()), sum(limit_mm_per_prompt.values()),
math.ceil(base_items_per_request * (1 + num_mm_items_range_ratio)) math.ceil(base_items_per_request * (1 + num_mm_items_range_ratio))
) )
# Ensure min num mm items is at least 0 # Ensure min num mm items is at least 0
min_num_mm_items = max( min_num_mm_items = max(
0, 0,
math.floor(base_items_per_request * (1 - num_mm_items_range_ratio)) math.floor(base_items_per_request * (1 - num_mm_items_range_ratio))
) )
# Raise error if min num mm items is greater than max num mm items # Raise error if min num mm items is greater than max num mm items
if min_num_mm_items > max_num_mm_items: if min_num_mm_items > max_num_mm_items:
raise ValueError(f"Min num mm items is greater than max mm items: " raise ValueError(f"Min num mm items is greater than max mm items: "
f"{min_num_mm_items} > {max_num_mm_items}") f"{min_num_mm_items} > {max_num_mm_items}")
logger.info( logger.info(
"Sampling number of multimodal items from [%s, %s]", "Sampling number of multimodal items from [%s, %s]",
min_num_mm_items, max_num_mm_items, min_num_mm_items, max_num_mm_items,
...@@ -783,8 +771,8 @@ class RandomMultiModalDataset(RandomDataset): ...@@ -783,8 +771,8 @@ class RandomMultiModalDataset(RandomDataset):
whose size is between min_num_mm_items and max_num_mm_items. whose size is between min_num_mm_items and max_num_mm_items.
Loop over the bucket config and sample a multimodal item. Loop over the bucket config and sample a multimodal item.
Loop until the number of multimodal items sampled is equal to Loop until the number of multimodal items sampled is equal to
request_num_mm_items or limit of multimodal items per prompt request_num_mm_items or limit of multimodal items per prompt
for all modalities is reached. for all modalities is reached.
Note: Note:
...@@ -796,19 +784,19 @@ class RandomMultiModalDataset(RandomDataset): ...@@ -796,19 +784,19 @@ class RandomMultiModalDataset(RandomDataset):
# Get the number of multimodal items to sample # Get the number of multimodal items to sample
request_num_mm_items = int( request_num_mm_items = int(
self._rng.integers(min_num_mm_items, max_num_mm_items + 1) self._rng.integers(min_num_mm_items, max_num_mm_items + 1)
) )
# If request_num_mm_items is 0, yield an empty iterator # If request_num_mm_items is 0, yield an empty iterator
if request_num_mm_items == 0: if request_num_mm_items == 0:
return return
# Initialize modality counters # Initialize modality counters
modality_counter = {self.map_config_to_modality(k): 0 modality_counter = {self.map_config_to_modality(k): 0
for k in bucket_config} for k in bucket_config}
# Copy the bucket config to avoid modifying the original # Copy the bucket config to avoid modifying the original
bucket_config_copy = bucket_config.copy() bucket_config_copy = bucket_config.copy()
# Loop over the number of multimodal items to sample # Loop over the number of multimodal items to sample
while sum(modality_counter.values()) < request_num_mm_items: while sum(modality_counter.values()) < request_num_mm_items:
# Sample a multimodal item config # Sample a multimodal item config
mm_item_config = self._rng.choice(list(bucket_config_copy.keys()), mm_item_config = self._rng.choice(list(bucket_config_copy.keys()),
p=list(bucket_config_copy.values())) p=list(bucket_config_copy.values()))
modality = self.map_config_to_modality(mm_item_config) modality = self.map_config_to_modality(mm_item_config)
# Check that modality count is less than limit per prompt # Check that modality count is less than limit per prompt
...@@ -849,7 +837,7 @@ class RandomMultiModalDataset(RandomDataset): ...@@ -849,7 +837,7 @@ class RandomMultiModalDataset(RandomDataset):
limit_mm_per_prompt: dict[str, int] = DEFAULT_LIMIT_MM_PER_PROMPT, limit_mm_per_prompt: dict[str, int] = DEFAULT_LIMIT_MM_PER_PROMPT,
base_items_per_request: int = DEFAULT_BASE_ITEMS_PER_REQUEST, base_items_per_request: int = DEFAULT_BASE_ITEMS_PER_REQUEST,
num_mm_items_range_ratio: float = DEFAULT_NUM_MM_ITEMS_RANGE_RATIO, num_mm_items_range_ratio: float = DEFAULT_NUM_MM_ITEMS_RANGE_RATIO,
bucket_config: dict[tuple[int, int, int], float] = bucket_config: dict[tuple[int, int, int], float] =
DEFAULT_MM_ITEM_BUCKET_CONFIG, DEFAULT_MM_ITEM_BUCKET_CONFIG,
enable_multimodal_chat: bool = DEFAULT_ENABLE_MULTIMODAL_CHAT, enable_multimodal_chat: bool = DEFAULT_ENABLE_MULTIMODAL_CHAT,
**kwargs, **kwargs,
...@@ -857,7 +845,7 @@ class RandomMultiModalDataset(RandomDataset): ...@@ -857,7 +845,7 @@ class RandomMultiModalDataset(RandomDataset):
# NOTE: Video sampling is WIP. Raise error if video is in bucket config # NOTE: Video sampling is WIP. Raise error if video is in bucket config
# and probability is non-zero. # and probability is non-zero.
if any(self.map_config_to_modality(cfg) == "video" and p > 0 if any(self.map_config_to_modality(cfg) == "video" and p > 0
for cfg, p in bucket_config.items()): for cfg, p in bucket_config.items()):
raise NotImplementedError("Video sampling not implemented; " raise NotImplementedError("Video sampling not implemented; "
"set its probability to 0.") "set its probability to 0.")
...@@ -908,7 +896,7 @@ class RandomMultiModalDataset(RandomDataset): ...@@ -908,7 +896,7 @@ class RandomMultiModalDataset(RandomDataset):
]) ])
if enable_multimodal_chat: if enable_multimodal_chat:
# NOTE: For now this option is only provided for completeness # NOTE: For now this option is only provided for completeness
# given that the serve.py benchmark currently does not use it. # given that the serve.py benchmark currently does not use it.
mm_chat_prompt: Any = prompt mm_chat_prompt: Any = prompt
mm_chat_prompt = self.apply_multimodal_chat_transformation( mm_chat_prompt = self.apply_multimodal_chat_transformation(
...@@ -982,8 +970,8 @@ class ShareGPTDataset(BenchmarkDataset): ...@@ -982,8 +970,8 @@ class ShareGPTDataset(BenchmarkDataset):
entry["conversations"][1]["value"], entry["conversations"][1]["value"],
) )
lora_request, tokenizer = self.get_random_lora_request( lora_request = self.get_random_lora_request(
tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path) max_loras=max_loras, lora_path=lora_path)
prompt_ids = tokenizer(prompt).input_ids prompt_ids = tokenizer(prompt).input_ids
completion_ids = tokenizer(completion).input_ids completion_ids = tokenizer(completion).input_ids
prompt_len = len(prompt_ids) prompt_len = len(prompt_ids)
...@@ -994,11 +982,11 @@ class ShareGPTDataset(BenchmarkDataset): ...@@ -994,11 +982,11 @@ class ShareGPTDataset(BenchmarkDataset):
skip_min_output_len_check=output_len skip_min_output_len_check=output_len
is not None): is not None):
continue continue
if image_path := entry.get("image"): if image_path := entry.get("image"):
mm_content = process_image(image_path) mm_content = process_image(image_path)
elif video_path := entry.get("video"): elif video_path := entry.get("video"):
mm_content = process_video(video_path) mm_content = process_video(video_path)
else: else:
mm_content = None mm_content = None
if enable_multimodal_chat: if enable_multimodal_chat:
prompt = self.apply_multimodal_chat_transformation( prompt = self.apply_multimodal_chat_transformation(
...@@ -1013,9 +1001,9 @@ class ShareGPTDataset(BenchmarkDataset): ...@@ -1013,9 +1001,9 @@ class ShareGPTDataset(BenchmarkDataset):
request_id=request_id_prefix + str(ind), request_id=request_id_prefix + str(ind),
)) ))
ind += 1 ind += 1
self.maybe_oversample_requests(samples, self.maybe_oversample_requests(samples,
num_requests, num_requests,
request_id_prefix, request_id_prefix,
no_oversample) no_oversample)
return samples return samples
...@@ -1024,11 +1012,11 @@ class _ValidateDatasetArgs(argparse.Action): ...@@ -1024,11 +1012,11 @@ class _ValidateDatasetArgs(argparse.Action):
"""Argparse action to validate dataset name and path compatibility.""" """Argparse action to validate dataset name and path compatibility."""
def __call__(self, parser, namespace, values, option_string=None): def __call__(self, parser, namespace, values, option_string=None):
setattr(namespace, self.dest, values) setattr(namespace, self.dest, values)
# Get current values of both dataset_name and dataset_path # Get current values of both dataset_name and dataset_path
dataset_name = getattr(namespace, 'dataset_name', 'random') dataset_name = getattr(namespace, 'dataset_name', 'random')
dataset_path = getattr(namespace, 'dataset_path', None) dataset_path = getattr(namespace, 'dataset_path', None)
# Validate the combination # Validate the combination
if dataset_name == "random" and dataset_path is not None: if dataset_name == "random" and dataset_path is not None:
parser.error( parser.error(
...@@ -1053,7 +1041,7 @@ def add_dataset_parser(parser: FlexibleArgumentParser): ...@@ -1053,7 +1041,7 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
default="random", default="random",
action=_ValidateDatasetArgs, action=_ValidateDatasetArgs,
choices=[ choices=[
"sharegpt", "burstgpt", "sonnet", "random", "random-mm", "hf", "sharegpt", "burstgpt", "sonnet", "random", "random-mm", "hf",
"custom", "prefix_repetition", "spec_bench" "custom", "prefix_repetition", "spec_bench"
], ],
help="Name of the dataset to benchmark on.", help="Name of the dataset to benchmark on.",
...@@ -1502,7 +1490,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: ...@@ -1502,7 +1490,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
# For datasets that follow a similar structure, use a mapping. # For datasets that follow a similar structure, use a mapping.
dataset_mapping = { dataset_mapping = {
"spec_bench": "spec_bench":
lambda: SpecBench(dataset_path=args.dataset_path, lambda: SpecBench(dataset_path=args.dataset_path,
category=args.spec_bench_category).sample( category=args.spec_bench_category).sample(
num_requests=args.num_prompts, num_requests=args.num_prompts,
tokenizer=tokenizer, tokenizer=tokenizer,
...@@ -1660,7 +1648,7 @@ class CustomDataset(BenchmarkDataset): ...@@ -1660,7 +1648,7 @@ class CustomDataset(BenchmarkDataset):
logger.info("num_requests is set to 0 or negative, " logger.info("num_requests is set to 0 or negative, "
"so using all available samples: %d", "so using all available samples: %d",
num_requests) num_requests)
sampled_requests = [] sampled_requests = []
for i, item in enumerate(self.data): for i, item in enumerate(self.data):
if len(sampled_requests) >= num_requests: if len(sampled_requests) >= num_requests:
...@@ -1686,7 +1674,7 @@ class CustomDataset(BenchmarkDataset): ...@@ -1686,7 +1674,7 @@ class CustomDataset(BenchmarkDataset):
expected_output_len=output_len, expected_output_len=output_len,
request_id=request_id_prefix + str(i), request_id=request_id_prefix + str(i),
)) ))
self.maybe_oversample_requests(sampled_requests, num_requests, self.maybe_oversample_requests(sampled_requests, num_requests,
request_id_prefix, no_oversample) request_id_prefix, no_oversample)
return sampled_requests return sampled_requests
...@@ -1700,7 +1688,7 @@ class CustomDataset(BenchmarkDataset): ...@@ -1700,7 +1688,7 @@ class CustomDataset(BenchmarkDataset):
class SpecBench(CustomDataset): class SpecBench(CustomDataset):
""" """
Implements the SpecBench dataset: https://github.com/hemingkx/Spec-Bench Implements the SpecBench dataset: https://github.com/hemingkx/Spec-Bench
Download the dataset using: Download the dataset using:
wget https://raw.githubusercontent.com/hemingkx/Spec-Bench/refs/heads/main/data/spec_bench/question.jsonl wget https://raw.githubusercontent.com/hemingkx/Spec-Bench/refs/heads/main/data/spec_bench/question.jsonl
""" # noqa: E501 """ # noqa: E501
...@@ -1736,8 +1724,8 @@ class SpecBench(CustomDataset): ...@@ -1736,8 +1724,8 @@ class SpecBench(CustomDataset):
# leverage CustomDataset sample # leverage CustomDataset sample
kwargs["skip_chat_template"] = False kwargs["skip_chat_template"] = False
return super().sample(**kwargs) return super().sample(**kwargs)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Sonnet Dataset Implementation # Sonnet Dataset Implementation
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
...@@ -1882,8 +1870,8 @@ class BurstGPTDataset(BenchmarkDataset): ...@@ -1882,8 +1870,8 @@ class BurstGPTDataset(BenchmarkDataset):
for i in range(num_requests): for i in range(num_requests):
input_len = int(data[i][2]) input_len = int(data[i][2])
output_len = int(data[i][3]) output_len = int(data[i][3])
lora_req, tokenizer = self.get_random_lora_request( lora_req = self.get_random_lora_request(
tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path) max_loras=max_loras, lora_path=lora_path)
vocab_size = tokenizer.vocab_size vocab_size = tokenizer.vocab_size
# Generate a synthetic prompt: a list of token IDs computed as (i + # Generate a synthetic prompt: a list of token IDs computed as (i +
# j) modulo vocab_size. # j) modulo vocab_size.
...@@ -1995,7 +1983,7 @@ class ConversationDataset(HuggingFaceDataset): ...@@ -1995,7 +1983,7 @@ class ConversationDataset(HuggingFaceDataset):
request_id=request_id_prefix + str(ind), request_id=request_id_prefix + str(ind),
)) ))
ind += 1 ind += 1
self.maybe_oversample_requests(sampled_requests, num_requests, self.maybe_oversample_requests(sampled_requests, num_requests,
request_id_prefix, no_oversample) request_id_prefix, no_oversample)
return sampled_requests return sampled_requests
...@@ -2055,7 +2043,7 @@ class VisionArenaDataset(HuggingFaceDataset): ...@@ -2055,7 +2043,7 @@ class VisionArenaDataset(HuggingFaceDataset):
multi_modal_data=mm_content, multi_modal_data=mm_content,
request_id=request_id_prefix + str(i), request_id=request_id_prefix + str(i),
)) ))
self.maybe_oversample_requests(sampled_requests, num_requests, self.maybe_oversample_requests(sampled_requests, num_requests,
request_id_prefix, no_oversample) request_id_prefix, no_oversample)
return sampled_requests return sampled_requests
...@@ -2172,7 +2160,7 @@ class InstructCoderDataset(HuggingFaceDataset): ...@@ -2172,7 +2160,7 @@ class InstructCoderDataset(HuggingFaceDataset):
expected_output_len=output_len, expected_output_len=output_len,
request_id=request_id_prefix + str(i), request_id=request_id_prefix + str(i),
)) ))
self.maybe_oversample_requests(sampled_requests, num_requests, self.maybe_oversample_requests(sampled_requests, num_requests,
request_id_prefix, no_oversample) request_id_prefix, no_oversample)
return sampled_requests return sampled_requests
...@@ -2234,7 +2222,7 @@ class MTBenchDataset(HuggingFaceDataset): ...@@ -2234,7 +2222,7 @@ class MTBenchDataset(HuggingFaceDataset):
expected_output_len=output_len, expected_output_len=output_len,
request_id=request_id_prefix + str(i), request_id=request_id_prefix + str(i),
)) ))
self.maybe_oversample_requests(sampled_requests, num_requests, self.maybe_oversample_requests(sampled_requests, num_requests,
request_id_prefix, no_oversample) request_id_prefix, no_oversample)
return sampled_requests return sampled_requests
...@@ -2288,8 +2276,8 @@ class BlazeditDataset(HuggingFaceDataset): ...@@ -2288,8 +2276,8 @@ class BlazeditDataset(HuggingFaceDataset):
# compare the levenshtein distance normalized by code length # compare the levenshtein distance normalized by code length
if norm_distance < min_distance or norm_distance > max_distance: if norm_distance < min_distance or norm_distance > max_distance:
continue continue
# template copied from # template copied from
# https://github.com/ise-uiuc/blazedit/blob/7765137e656fd62de877422d2e4cf8de51228054/dataset/create_refined_dataset.py#L94-L105 # noqa: E501 # https://github.com/ise-uiuc/blazedit/blob/7765137e656fd62de877422d2e4cf8de51228054/dataset/create_refined_dataset.py#L94-L105 # noqa: E501
instruction = f"""Given a code file, please apply the change requests and generate the new file. instruction = f"""Given a code file, please apply the change requests and generate the new file.
...@@ -2322,9 +2310,9 @@ Please generate the new code file in the "New file" section below.""" # noqa: E5 ...@@ -2322,9 +2310,9 @@ Please generate the new code file in the "New file" section below.""" # noqa: E5
expected_output_len=output_len, expected_output_len=output_len,
request_id=request_id_prefix + str(i), request_id=request_id_prefix + str(i),
)) ))
self.maybe_oversample_requests(sampled_requests, num_requests, self.maybe_oversample_requests(sampled_requests, num_requests,
request_id_prefix, no_oversample) request_id_prefix, no_oversample)
return sampled_requests return sampled_requests
...@@ -2376,7 +2364,6 @@ class AIMODataset(HuggingFaceDataset): ...@@ -2376,7 +2364,6 @@ class AIMODataset(HuggingFaceDataset):
expected_output_len=output_len, expected_output_len=output_len,
multi_modal_data=None, multi_modal_data=None,
request_id=request_id_prefix + str(ind), request_id=request_id_prefix + str(ind),
)) ))
ind += 1 ind += 1
self.maybe_oversample_requests(sampled_requests, num_requests, self.maybe_oversample_requests(sampled_requests, num_requests,
...@@ -2470,9 +2457,9 @@ class NextEditPredictionDataset(HuggingFaceDataset): ...@@ -2470,9 +2457,9 @@ class NextEditPredictionDataset(HuggingFaceDataset):
)) ))
if len(samples) >= num_requests: if len(samples) >= num_requests:
break break
self.maybe_oversample_requests(samples, self.maybe_oversample_requests(samples,
num_requests, num_requests,
request_id_prefix, request_id_prefix,
no_oversample) no_oversample)
return samples return samples
...@@ -2562,7 +2549,7 @@ class ASRDataset(HuggingFaceDataset): ...@@ -2562,7 +2549,7 @@ class ASRDataset(HuggingFaceDataset):
" what Whisper supports.", " what Whisper supports.",
skipped, skipped,
) )
self.maybe_oversample_requests(sampled_requests, num_requests, self.maybe_oversample_requests(sampled_requests, num_requests,
request_id_prefix, no_oversample) request_id_prefix, no_oversample)
return sampled_requests return sampled_requests
...@@ -2647,7 +2634,7 @@ class MLPerfDataset(HuggingFaceDataset): ...@@ -2647,7 +2634,7 @@ class MLPerfDataset(HuggingFaceDataset):
) )
ind += 1 ind += 1
self.maybe_oversample_requests(sampled_requests, num_requests, self.maybe_oversample_requests(sampled_requests, num_requests,
request_id_prefix, no_oversample) request_id_prefix, no_oversample)
return sampled_requests return sampled_requests
...@@ -2658,7 +2645,7 @@ class MLPerfDataset(HuggingFaceDataset): ...@@ -2658,7 +2645,7 @@ class MLPerfDataset(HuggingFaceDataset):
class PrefixRepetitionRandomDataset(BenchmarkDataset): class PrefixRepetitionRandomDataset(BenchmarkDataset):
# Default values copied from benchmark_serving.py for the repeated prefix # Default values copied from benchmark_serving.py for the repeated prefix
# dataset. # dataset.
DEFAULT_PREFIX_LEN = 256 DEFAULT_PREFIX_LEN = 256
DEFAULT_SUFFIX_LEN = 256 DEFAULT_SUFFIX_LEN = 256
......
...@@ -390,11 +390,8 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -390,11 +390,8 @@ class _AsyncLLMEngine(LLMEngine):
"""Stop the remote worker execution loop.""" """Stop the remote worker execution loop."""
await self.model_executor.stop_remote_worker_execution_loop_async() await self.model_executor.stop_remote_worker_execution_loop_async()
async def get_tokenizer_async(self, async def get_tokenizer_async(self) -> AnyTokenizer:
lora_request: Optional[LoRARequest] = None return self.get_tokenizer()
) -> AnyTokenizer:
return await (
self.get_tokenizer_group().get_lora_tokenizer_async(lora_request))
async def add_request_async( async def add_request_async(
self, self,
...@@ -435,7 +432,6 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -435,7 +432,6 @@ class _AsyncLLMEngine(LLMEngine):
processed_inputs = await self.input_preprocessor.preprocess_async( processed_inputs = await self.input_preprocessor.preprocess_async(
prompt, prompt,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
) )
...@@ -614,11 +610,8 @@ class AsyncLLMEngine(EngineClient): ...@@ -614,11 +610,8 @@ class AsyncLLMEngine(EngineClient):
async def get_input_preprocessor(self) -> InputPreprocessor: async def get_input_preprocessor(self) -> InputPreprocessor:
return self.engine.input_preprocessor return self.engine.input_preprocessor
async def get_tokenizer( async def get_tokenizer(self) -> AnyTokenizer:
self, return self.engine.get_tokenizer()
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
return await self.engine.get_tokenizer_async(lora_request)
def start_background_loop(self) -> None: def start_background_loop(self) -> None:
"""Start the background loop.""" """Start the background loop."""
......
...@@ -49,9 +49,8 @@ from vllm.sequence import (ExecuteModelRequest, ParallelSampleSequenceGroup, ...@@ -49,9 +49,8 @@ from vllm.sequence import (ExecuteModelRequest, ParallelSampleSequenceGroup,
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context, from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
init_tracer) init_tracer)
from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import (AnyTokenizer,
from vllm.transformers_utils.tokenizer_group import ( init_tokenizer_from_configs)
TokenizerGroup, init_tokenizer_from_configs)
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message) usage_message)
from vllm.utils import Counter, Device, resolve_obj_by_qualname, weak_bind from vllm.utils import Counter, Device, resolve_obj_by_qualname, weak_bind
...@@ -186,7 +185,7 @@ class LLMEngine: ...@@ -186,7 +185,7 @@ class LLMEngine:
return outputs_ return outputs_
tokenizer: Optional[TokenizerGroup] tokenizer: Optional[AnyTokenizer]
def __init__( def __init__(
self, self,
...@@ -233,18 +232,9 @@ class LLMEngine: ...@@ -233,18 +232,9 @@ class LLMEngine:
if self.model_config.skip_tokenizer_init: if self.model_config.skip_tokenizer_init:
self.tokenizer = None self.tokenizer = None
self.detokenizer = None self.detokenizer = None
tokenizer_group = None
else: else:
self.tokenizer = self._init_tokenizer() self.tokenizer = self._init_tokenizer()
self.detokenizer = Detokenizer(self.tokenizer) self.detokenizer = Detokenizer(self.tokenizer)
tokenizer_group = self.get_tokenizer_group()
# Ensure that the function doesn't contain a reference to self,
# to avoid engine GC issues
def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
assert tokenizer_group, ("tokenizer_group cannot be None, "
"make sure skip_tokenizer_init is False")
return tokenizer_group.get_lora_tokenizer(sequence.lora_request)
self.seq_counter = Counter() self.seq_counter = Counter()
self.generation_config_fields = ( self.generation_config_fields = (
...@@ -389,10 +379,8 @@ class LLMEngine: ...@@ -389,10 +379,8 @@ class LLMEngine:
self.detokenizer, self.detokenizer,
self.scheduler, self.scheduler,
self.seq_counter, self.seq_counter,
get_tokenizer_for_seq,
stop_checker=StopChecker( stop_checker=StopChecker(
self.scheduler_config.max_model_len, self.scheduler_config.max_model_len,
get_tokenizer_for_seq,
self.reasoner if self.decoding_config.reasoning_backend self.reasoner if self.decoding_config.reasoning_backend
and self.tokenizer else None, and self.tokenizer else None,
), ),
...@@ -521,24 +509,15 @@ class LLMEngine: ...@@ -521,24 +509,15 @@ class LLMEngine:
if model_executor := getattr(self, "model_executor", None): if model_executor := getattr(self, "model_executor", None):
model_executor.shutdown() model_executor.shutdown()
def get_tokenizer_group(self) -> TokenizerGroup: def get_tokenizer(self) -> AnyTokenizer:
if self.tokenizer is None: if self.tokenizer is None:
raise ValueError("Unable to get tokenizer because " raise ValueError("Unable to get tokenizer because "
"skip_tokenizer_init is True") "skip_tokenizer_init is True")
return self.tokenizer return self.tokenizer
def get_tokenizer( def _init_tokenizer(self) -> AnyTokenizer:
self, return init_tokenizer_from_configs(model_config=self.model_config)
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
return self.get_tokenizer_group().get_lora_tokenizer(lora_request)
def _init_tokenizer(self) -> TokenizerGroup:
return init_tokenizer_from_configs(
model_config=self.model_config,
scheduler_config=self.scheduler_config,
lora_config=self.lora_config)
def _verify_args(self) -> None: def _verify_args(self) -> None:
self.model_config.verify_with_parallel_config(self.parallel_config) self.model_config.verify_with_parallel_config(self.parallel_config)
...@@ -574,11 +553,11 @@ class LLMEngine: ...@@ -574,11 +553,11 @@ class LLMEngine:
) )
return None return None
self._validate_model_inputs(processed_inputs, lora_request) self._validate_model_inputs(processed_inputs)
# Create the sequences. # Create the sequences.
block_size = self.cache_config.block_size block_size = self.cache_config.block_size
seq_id = next(self.seq_counter) seq_id = next(self.seq_counter)
eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request) eos_token_id = self.input_preprocessor.get_eos_token_id()
encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs) encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
...@@ -700,7 +679,6 @@ class LLMEngine: ...@@ -700,7 +679,6 @@ class LLMEngine:
processed_inputs = self.input_preprocessor.preprocess( processed_inputs = self.input_preprocessor.preprocess(
prompt, prompt,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
) )
self._add_processed_request( self._add_processed_request(
...@@ -1739,29 +1717,22 @@ class LLMEngine: ...@@ -1739,29 +1717,22 @@ class LLMEngine:
SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_EXECUTE, SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_EXECUTE,
metrics.model_execute_time) metrics.model_execute_time)
def _validate_model_inputs(self, inputs: ProcessorInputs, def _validate_model_inputs(self, inputs: ProcessorInputs):
lora_request: Optional[LoRARequest]):
encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs) encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs)
if encoder_inputs is not None: if encoder_inputs is not None:
self._validate_model_input(encoder_inputs, self._validate_model_input(encoder_inputs, prompt_type="encoder")
lora_request,
prompt_type="encoder")
self._validate_model_input(decoder_inputs, self._validate_model_input(decoder_inputs, prompt_type="decoder")
lora_request,
prompt_type="decoder")
def _validate_model_input( def _validate_model_input(
self, self,
prompt_inputs: SingletonInputs, prompt_inputs: SingletonInputs,
lora_request: Optional[LoRARequest],
*, *,
prompt_type: Literal["encoder", "decoder"], prompt_type: Literal["encoder", "decoder"],
): ):
model_config = self.model_config model_config = self.model_config
tokenizer = (None if self.tokenizer is None else tokenizer = self.tokenizer
self.tokenizer.get_lora_tokenizer(lora_request))
prompt_ids = prompt_inputs.get("prompt_token_ids", []) prompt_ids = prompt_inputs.get("prompt_token_ids", [])
if not prompt_ids: if not prompt_ids:
...@@ -1822,7 +1793,7 @@ class LLMEngine: ...@@ -1822,7 +1793,7 @@ class LLMEngine:
logits_processors = [] logits_processors = []
if (sampling_params.logit_bias or sampling_params.allowed_token_ids): if (sampling_params.logit_bias or sampling_params.allowed_token_ids):
tokenizer = self.get_tokenizer(lora_request=lora_request) tokenizer = self.get_tokenizer()
processors = get_openai_logits_processors( processors = get_openai_logits_processors(
logit_bias=sampling_params.logit_bias, logit_bias=sampling_params.logit_bias,
...@@ -1835,7 +1806,7 @@ class LLMEngine: ...@@ -1835,7 +1806,7 @@ class LLMEngine:
sampling_params.allowed_token_ids = None sampling_params.allowed_token_ids = None
if len(sampling_params.bad_words) > 0: if len(sampling_params.bad_words) > 0:
tokenizer = self.get_tokenizer(lora_request) tokenizer = self.get_tokenizer()
processors = get_bad_words_logits_processors( processors = get_bad_words_logits_processors(
bad_words=sampling_params.bad_words, tokenizer=tokenizer) bad_words=sampling_params.bad_words, tokenizer=tokenizer)
logits_processors.extend(processors) logits_processors.extend(processors)
......
...@@ -2,14 +2,13 @@ ...@@ -2,14 +2,13 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Callable, List from typing import List
from vllm.config import SchedulerConfig from vllm.config import SchedulerConfig
from vllm.core.scheduler import Scheduler from vllm.core.scheduler import Scheduler
from vllm.engine.output_processor.stop_checker import StopChecker from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.sequence import Sequence, SequenceGroup, SequenceGroupOutput from vllm.sequence import SequenceGroup, SequenceGroupOutput
from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import Counter from vllm.utils import Counter
...@@ -31,7 +30,6 @@ class SequenceGroupOutputProcessor(ABC): ...@@ -31,7 +30,6 @@ class SequenceGroupOutputProcessor(ABC):
detokenizer: Detokenizer, detokenizer: Detokenizer,
scheduler: List[Scheduler], scheduler: List[Scheduler],
seq_counter: Counter, seq_counter: Counter,
get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer],
stop_checker: "StopChecker", stop_checker: "StopChecker",
): ):
"""Create an output processor. """Create an output processor.
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable, List, Optional, Tuple from typing import List, Optional, Tuple
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.reasoning import ReasoningParser from vllm.reasoning import ReasoningParser
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import Sequence, SequenceStatus from vllm.sequence import Sequence, SequenceStatus
from vllm.transformers_utils.tokenizer import AnyTokenizer
class StopChecker: class StopChecker:
...@@ -20,12 +19,10 @@ class StopChecker: ...@@ -20,12 +19,10 @@ class StopChecker:
def __init__( def __init__(
self, self,
max_model_len: int, max_model_len: int,
get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer],
reasoner: Optional[ReasoningParser] = None, reasoner: Optional[ReasoningParser] = None,
): ):
# Do not use it directly, but use `self._get_max_model_len`. # Do not use it directly, but use `self._get_max_model_len`.
self._max_model_len = max_model_len self._max_model_len = max_model_len
self.get_tokenizer_for_seq = get_tokenizer_for_seq
self.reasoner = reasoner self.reasoner = reasoner
def _get_max_model_len(self, lora_req: Optional[LoRARequest]): def _get_max_model_len(self, lora_req: Optional[LoRARequest]):
......
...@@ -76,8 +76,7 @@ class EngineClient(ABC): ...@@ -76,8 +76,7 @@ class EngineClient(ABC):
include_stop_str_in_output = params.include_stop_str_in_output include_stop_str_in_output = params.include_stop_str_in_output
preprocessor = await self.get_input_preprocessor() preprocessor = await self.get_input_preprocessor()
tokenizer_group = preprocessor.get_tokenizer_group() tokenizer = preprocessor.get_tokenizer()
tokenizer = await tokenizer_group.get_lora_tokenizer_async()
eos_token_id = tokenizer.eos_token_id eos_token_id = tokenizer.eos_token_id
if is_explicit_encoder_decoder_prompt(prompt): if is_explicit_encoder_decoder_prompt(prompt):
...@@ -260,11 +259,8 @@ class EngineClient(ABC): ...@@ -260,11 +259,8 @@ class EngineClient(ABC):
... ...
@abstractmethod @abstractmethod
async def get_tokenizer( async def get_tokenizer(self) -> AnyTokenizer:
self, """Get the tokenizer"""
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
"""Get the appropriate tokenizer for the request"""
... ...
async def get_io_processor(self) -> IOProcessor: async def get_io_processor(self) -> IOProcessor:
......
...@@ -301,23 +301,17 @@ class LLM: ...@@ -301,23 +301,17 @@ class LLM:
self.io_processor = get_io_processor(self.llm_engine.vllm_config, self.io_processor = get_io_processor(self.llm_engine.vllm_config,
io_processor_plugin) io_processor_plugin)
def get_tokenizer( def get_tokenizer(self) -> AnyTokenizer:
self, return self.llm_engine.get_tokenizer()
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
return self.llm_engine.get_tokenizer_group().get_lora_tokenizer(
lora_request)
def set_tokenizer(self, tokenizer: AnyTokenizer) -> None: def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
tokenizer_group = self.llm_engine.get_tokenizer_group()
# While CachedTokenizer is dynamic, have no choice but # While CachedTokenizer is dynamic, have no choice but
# compare class name. Misjudgment will arise from # compare class name. Misjudgment will arise from
# user-defined tokenizer started with 'Cached' # user-defined tokenizer started with 'Cached'
if tokenizer.__class__.__name__.startswith("Cached"): if tokenizer.__class__.__name__.startswith("Cached"):
tokenizer_group.tokenizer = tokenizer self.llm_engine.tokenizer = tokenizer
else: else:
tokenizer_group.tokenizer = get_cached_tokenizer(tokenizer) self.llm_engine.tokenizer = get_cached_tokenizer(tokenizer)
def get_default_sampling_params(self) -> SamplingParams: def get_default_sampling_params(self) -> SamplingParams:
if self.default_sampling_params is None: if self.default_sampling_params is None:
...@@ -707,7 +701,6 @@ class LLM: ...@@ -707,7 +701,6 @@ class LLM:
self, self,
messages: Union[list[ChatCompletionMessageParam], messages: Union[list[ChatCompletionMessageParam],
list[list[ChatCompletionMessageParam]]], list[list[ChatCompletionMessageParam]]],
lora_request: Optional[LoRARequest] = None,
chat_template: Optional[str] = None, chat_template: Optional[str] = None,
chat_template_content_format: ChatTemplateContentFormatOption = "auto", chat_template_content_format: ChatTemplateContentFormatOption = "auto",
add_generation_prompt: bool = True, add_generation_prompt: bool = True,
...@@ -739,7 +732,7 @@ class LLM: ...@@ -739,7 +732,7 @@ class LLM:
cast(list[ChatCompletionMessageParam], messages) cast(list[ChatCompletionMessageParam], messages)
] ]
tokenizer = self.get_tokenizer(lora_request) tokenizer = self.get_tokenizer()
model_config = self.llm_engine.get_model_config() model_config = self.llm_engine.get_model_config()
resolved_content_format = resolve_chat_template_content_format( resolved_content_format = resolve_chat_template_content_format(
chat_template, chat_template,
...@@ -872,7 +865,6 @@ class LLM: ...@@ -872,7 +865,6 @@ class LLM:
prompts = self.preprocess_chat( prompts = self.preprocess_chat(
messages=messages, messages=messages,
lora_request=lora_request,
chat_template=chat_template, chat_template=chat_template,
chat_template_content_format=chat_template_content_format, chat_template_content_format=chat_template_content_format,
add_generation_prompt=add_generation_prompt, add_generation_prompt=add_generation_prompt,
...@@ -1519,7 +1511,7 @@ class LLM: ...@@ -1519,7 +1511,7 @@ class LLM:
): ):
""" """
Validate that if any multi-modal data is skipped (i.e. None), Validate that if any multi-modal data is skipped (i.e. None),
then its corresponding UUID must be set. then its corresponding UUID must be set.
""" """
if multi_modal_data is None: if multi_modal_data is None:
return return
......
...@@ -188,7 +188,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -188,7 +188,7 @@ class OpenAIServingChat(OpenAIServing):
model_name = self.models.model_name(lora_request) model_name = self.models.model_name(lora_request)
tokenizer = await self.engine_client.get_tokenizer(lora_request) tokenizer = await self.engine_client.get_tokenizer()
tool_parser = self.tool_parser tool_parser = self.tool_parser
......
...@@ -50,10 +50,7 @@ class ClassificationMixin(OpenAIServing): ...@@ -50,10 +50,7 @@ class ClassificationMixin(OpenAIServing):
return None return None
try: try:
ctx.lora_request = self._maybe_get_adapters(ctx.request) ctx.tokenizer = await self.engine_client.get_tokenizer()
ctx.tokenizer = await self.engine_client.get_tokenizer(
ctx.lora_request)
renderer = self._get_renderer(ctx.tokenizer) renderer = self._get_renderer(ctx.tokenizer)
ctx.engine_prompts = await renderer.render_prompt( ctx.engine_prompts = await renderer.render_prompt(
......
...@@ -127,8 +127,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -127,8 +127,7 @@ class OpenAIServingCompletion(OpenAIServing):
if self.model_config.skip_tokenizer_init: if self.model_config.skip_tokenizer_init:
tokenizer = None tokenizer = None
else: else:
tokenizer = await self.engine_client.get_tokenizer(lora_request tokenizer = await self.engine_client.get_tokenizer()
)
renderer = self._get_renderer(tokenizer) renderer = self._get_renderer(tokenizer)
engine_prompts = await renderer.render_prompt_and_embeds( engine_prompts = await renderer.render_prompt_and_embeds(
......
...@@ -76,8 +76,7 @@ class EmbeddingMixin(OpenAIServing): ...@@ -76,8 +76,7 @@ class EmbeddingMixin(OpenAIServing):
try: try:
ctx.lora_request = self._maybe_get_adapters(ctx.request) ctx.lora_request = self._maybe_get_adapters(ctx.request)
tokenizer = await self.engine_client.get_tokenizer(ctx.lora_request tokenizer = await self.engine_client.get_tokenizer()
)
renderer = self._get_renderer(tokenizer) renderer = self._get_renderer(tokenizer)
if isinstance(ctx.request, EmbeddingChatRequest): if isinstance(ctx.request, EmbeddingChatRequest):
...@@ -394,8 +393,8 @@ class EmbeddingMixin(OpenAIServing): ...@@ -394,8 +393,8 @@ class EmbeddingMixin(OpenAIServing):
) -> Optional[ErrorResponse]: ) -> Optional[ErrorResponse]:
"""Collect and aggregate batch results """Collect and aggregate batch results
with support for chunked processing. with support for chunked processing.
For chunked requests, performs online aggregation to For chunked requests, performs online aggregation to
minimize memory usage. minimize memory usage.
For regular requests, collects results normally. For regular requests, collects results normally.
""" """
......
...@@ -103,8 +103,7 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -103,8 +103,7 @@ class OpenAIServingPooling(OpenAIServing):
if self.model_config.skip_tokenizer_init: if self.model_config.skip_tokenizer_init:
tokenizer = None tokenizer = None
else: else:
tokenizer = await self.engine_client.get_tokenizer(lora_request tokenizer = await self.engine_client.get_tokenizer()
)
renderer = self._get_renderer(tokenizer) renderer = self._get_renderer(tokenizer)
if getattr(request, "dimensions", None) is not None: if getattr(request, "dimensions", None) is not None:
......
...@@ -240,7 +240,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -240,7 +240,7 @@ class OpenAIServingResponses(OpenAIServing):
try: try:
lora_request = self._maybe_get_adapters(request) lora_request = self._maybe_get_adapters(request)
model_name = self.models.model_name(lora_request) model_name = self.models.model_name(lora_request)
tokenizer = await self.engine_client.get_tokenizer(lora_request) tokenizer = await self.engine_client.get_tokenizer()
if self.use_harmony: if self.use_harmony:
messages, request_prompts, engine_prompts = ( messages, request_prompts, engine_prompts = (
......
...@@ -269,7 +269,7 @@ class ServingScores(OpenAIServing): ...@@ -269,7 +269,7 @@ class ServingScores(OpenAIServing):
) -> Union[list[PoolingRequestOutput], ErrorResponse]: ) -> Union[list[PoolingRequestOutput], ErrorResponse]:
lora_request = self._maybe_get_adapters(request) lora_request = self._maybe_get_adapters(request)
tokenizer = await self.engine_client.get_tokenizer(lora_request) tokenizer = await self.engine_client.get_tokenizer()
truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens",
None) None)
......
...@@ -65,7 +65,7 @@ class OpenAIServingTokenization(OpenAIServing): ...@@ -65,7 +65,7 @@ class OpenAIServingTokenization(OpenAIServing):
try: try:
lora_request = self._maybe_get_adapters(request) lora_request = self._maybe_get_adapters(request)
tokenizer = await self.engine_client.get_tokenizer(lora_request) tokenizer = await self.engine_client.get_tokenizer()
renderer = self._get_renderer(tokenizer) renderer = self._get_renderer(tokenizer)
if isinstance(request, TokenizeChatRequest): if isinstance(request, TokenizeChatRequest):
...@@ -130,7 +130,7 @@ class OpenAIServingTokenization(OpenAIServing): ...@@ -130,7 +130,7 @@ class OpenAIServingTokenization(OpenAIServing):
lora_request = self._maybe_get_adapters(request) lora_request = self._maybe_get_adapters(request)
tokenizer = await self.engine_client.get_tokenizer(lora_request) tokenizer = await self.engine_client.get_tokenizer()
self._log_inputs(request_id, self._log_inputs(request_id,
request.tokens, request.tokens,
......
...@@ -9,13 +9,11 @@ from typing_extensions import assert_never ...@@ -9,13 +9,11 @@ from typing_extensions import assert_never
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.cache import BaseMultiModalProcessorCache from vllm.multimodal.cache import BaseMultiModalProcessorCache
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
MultiModalInputs, MultiModalUUIDDict) MultiModalInputs, MultiModalUUIDDict)
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from .data import (DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt, from .data import (DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt,
EncoderDecoderInputs, ProcessorInputs, PromptType, EncoderDecoderInputs, ProcessorInputs, PromptType,
...@@ -31,7 +29,7 @@ class InputPreprocessor: ...@@ -31,7 +29,7 @@ class InputPreprocessor:
def __init__( def __init__(
self, self,
model_config: ModelConfig, model_config: ModelConfig,
tokenizer: Optional[TokenizerGroup], tokenizer: Optional[AnyTokenizer],
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
mm_processor_cache: Optional[BaseMultiModalProcessorCache] = None, mm_processor_cache: Optional[BaseMultiModalProcessorCache] = None,
) -> None: ) -> None:
...@@ -42,32 +40,28 @@ class InputPreprocessor: ...@@ -42,32 +40,28 @@ class InputPreprocessor:
self.mm_registry = mm_registry self.mm_registry = mm_registry
self.mm_processor_cache = mm_processor_cache self.mm_processor_cache = mm_processor_cache
def get_tokenizer_group(self) -> TokenizerGroup: def get_tokenizer(self) -> AnyTokenizer:
if self.tokenizer is None: if self.tokenizer is None:
raise ValueError("You cannot pass text prompts when " raise ValueError("You cannot pass text prompts when "
"`skip_tokenizer_init` is True") "`skip_tokenizer_init` is True")
return self.tokenizer return self.tokenizer
def get_bos_token_id(self, def get_bos_token_id(self) -> Optional[int]:
lora_request: Optional[LoRARequest] = None
) -> Optional[int]:
if self.tokenizer is None: if self.tokenizer is None:
logger.warning("Using None for BOS token id because tokenizer " logger.warning("Using None for BOS token id because tokenizer "
"is not initialized") "is not initialized")
return None return None
return self.tokenizer.get_lora_tokenizer(lora_request).bos_token_id return self.tokenizer.bos_token_id
def get_eos_token_id(self, def get_eos_token_id(self) -> Optional[int]:
lora_request: Optional[LoRARequest] = None
) -> Optional[int]:
if self.tokenizer is None: if self.tokenizer is None:
logger.warning("Using None for EOS token id because tokenizer " logger.warning("Using None for EOS token id because tokenizer "
"is not initialized") "is not initialized")
return None return None
return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id return self.tokenizer.eos_token_id
def get_decoder_start_token_id(self) -> Optional[int]: def get_decoder_start_token_id(self) -> Optional[int]:
""" """
...@@ -190,14 +184,13 @@ class InputPreprocessor: ...@@ -190,14 +184,13 @@ class InputPreprocessor:
def _tokenize_prompt( def _tokenize_prompt(
self, self,
prompt: str, prompt: str,
lora_request: Optional[LoRARequest],
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[int]: ) -> list[int]:
""" """
Apply the model's tokenizer to a text prompt, returning the Apply the model's tokenizer to a text prompt, returning the
corresponding token IDs. corresponding token IDs.
""" """
tokenizer = self.get_tokenizer_group() tokenizer = self.get_tokenizer()
tokenization_kwargs = self._get_tokenization_kw(tokenization_kwargs) tokenization_kwargs = self._get_tokenization_kw(tokenization_kwargs)
encoder_config = self.model_config.encoder_config encoder_config = self.model_config.encoder_config
...@@ -205,50 +198,39 @@ class InputPreprocessor: ...@@ -205,50 +198,39 @@ class InputPreprocessor:
if encoder_config and encoder_config.get("do_lower_case", False): if encoder_config and encoder_config.get("do_lower_case", False):
prompt = prompt.lower() prompt = prompt.lower()
return tokenizer.encode(prompt=prompt, return tokenizer.encode(prompt, **tokenization_kwargs)
lora_request=lora_request,
**tokenization_kwargs)
async def _tokenize_prompt_async( async def _tokenize_prompt_async(
self, self,
prompt: str, prompt: str,
lora_request: Optional[LoRARequest],
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[int]: ) -> list[int]:
""" """
Async version of Async version of
[`_tokenize_prompt`][vllm.inputs.preprocess.InputPreprocessor._tokenize_prompt]. [`_tokenize_prompt`][vllm.inputs.preprocess.InputPreprocessor._tokenize_prompt].
""" """
tokenizer = self.get_tokenizer_group() tokenizer = self.get_tokenizer()
tokenization_kwargs = self._get_tokenization_kw(tokenization_kwargs) tokenization_kwargs = self._get_tokenization_kw(tokenization_kwargs)
return await tokenizer.encode_async(prompt=prompt, return tokenizer.encode(prompt, **tokenization_kwargs)
lora_request=lora_request,
**tokenization_kwargs)
def _get_mm_tokenizer( def _get_mm_tokenizer(self) -> AnyTokenizer:
self,
lora_request: Optional[LoRARequest],
) -> AnyTokenizer:
# PrithviGeoSpatialMAE needs to be initialized without a tokenizer # PrithviGeoSpatialMAE needs to be initialized without a tokenizer
# while using also multi-modal input # while using also multi-modal input
if not self.tokenizer: if not self.tokenizer:
return cast(AnyTokenizer, object()) # Dummy return cast(AnyTokenizer, object()) # Dummy
tokenizer_group = self.get_tokenizer_group() tokenizer = self.get_tokenizer()
return tokenizer_group.get_lora_tokenizer(lora_request) return tokenizer
async def _get_mm_tokenizer_async( async def _get_mm_tokenizer_async(self) -> AnyTokenizer:
self,
lora_request: Optional[LoRARequest],
) -> AnyTokenizer:
# PrithviGeoSpatialMAE needs to be initialized without a tokenizer # PrithviGeoSpatialMAE needs to be initialized without a tokenizer
# while using also multi-modal input # while using also multi-modal input
if not self.tokenizer: if not self.tokenizer:
return cast(AnyTokenizer, object()) # Dummy return cast(AnyTokenizer, object()) # Dummy
tokenizer_group = self.get_tokenizer_group() tokenizer = self.get_tokenizer()
return await tokenizer_group.get_lora_tokenizer_async(lora_request) return tokenizer
def _process_multimodal( def _process_multimodal(
self, self,
...@@ -256,7 +238,6 @@ class InputPreprocessor: ...@@ -256,7 +238,6 @@ class InputPreprocessor:
mm_data: MultiModalDataDict, mm_data: MultiModalDataDict,
mm_processor_kwargs: Optional[Mapping[str, object]], mm_processor_kwargs: Optional[Mapping[str, object]],
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*, *,
mm_uuids: Optional[MultiModalUUIDDict] = None, mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> MultiModalInputs: ) -> MultiModalInputs:
...@@ -264,7 +245,7 @@ class InputPreprocessor: ...@@ -264,7 +245,7 @@ class InputPreprocessor:
Apply the model's multi-modal processor to a multi-modal prompt, Apply the model's multi-modal processor to a multi-modal prompt,
returning the corresponding token IDs and metadata. returning the corresponding token IDs and metadata.
""" """
tokenizer = self._get_mm_tokenizer(lora_request) tokenizer = self._get_mm_tokenizer()
mm_processor = self.mm_registry.create_processor( mm_processor = self.mm_registry.create_processor(
self.model_config, self.model_config,
...@@ -299,7 +280,6 @@ class InputPreprocessor: ...@@ -299,7 +280,6 @@ class InputPreprocessor:
mm_data: MultiModalDataDict, mm_data: MultiModalDataDict,
mm_processor_kwargs: Optional[Mapping[str, object]], mm_processor_kwargs: Optional[Mapping[str, object]],
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*, *,
mm_uuids: Optional[MultiModalUUIDDict] = None, mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> MultiModalInputs: ) -> MultiModalInputs:
...@@ -307,7 +287,7 @@ class InputPreprocessor: ...@@ -307,7 +287,7 @@ class InputPreprocessor:
Async version of Async version of
[`_process_multimodal`][vllm.inputs.preprocess.InputPreprocessor._process_multimodal]. [`_process_multimodal`][vllm.inputs.preprocess.InputPreprocessor._process_multimodal].
""" """
tokenizer = await self._get_mm_tokenizer_async(lora_request) tokenizer = await self._get_mm_tokenizer_async()
mm_processor = self.mm_registry.create_processor( mm_processor = self.mm_registry.create_processor(
self.model_config, self.model_config,
...@@ -386,7 +366,6 @@ class InputPreprocessor: ...@@ -386,7 +366,6 @@ class InputPreprocessor:
self, self,
parsed_content: TokensPrompt, parsed_content: TokensPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*, *,
mm_uuids: Optional[MultiModalUUIDDict] = None, mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> Union[TokenInputs, MultiModalInputs]: ) -> Union[TokenInputs, MultiModalInputs]:
...@@ -400,7 +379,6 @@ class InputPreprocessor: ...@@ -400,7 +379,6 @@ class InputPreprocessor:
multi_modal_data, multi_modal_data,
parsed_content.get("mm_processor_kwargs"), parsed_content.get("mm_processor_kwargs"),
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids, mm_uuids=mm_uuids,
) )
else: else:
...@@ -415,7 +393,6 @@ class InputPreprocessor: ...@@ -415,7 +393,6 @@ class InputPreprocessor:
self, self,
parsed_content: TokensPrompt, parsed_content: TokensPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*, *,
mm_uuids: Optional[MultiModalUUIDDict] = None, mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> Union[TokenInputs, MultiModalInputs]: ) -> Union[TokenInputs, MultiModalInputs]:
...@@ -429,7 +406,6 @@ class InputPreprocessor: ...@@ -429,7 +406,6 @@ class InputPreprocessor:
multi_modal_data, multi_modal_data,
parsed_content.get("mm_processor_kwargs"), parsed_content.get("mm_processor_kwargs"),
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids, mm_uuids=mm_uuids,
) )
else: else:
...@@ -444,7 +420,6 @@ class InputPreprocessor: ...@@ -444,7 +420,6 @@ class InputPreprocessor:
self, self,
parsed_content: TextPrompt, parsed_content: TextPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*, *,
mm_uuids: Optional[MultiModalUUIDDict] = None, mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> Union[TokenInputs, MultiModalInputs]: ) -> Union[TokenInputs, MultiModalInputs]:
...@@ -457,13 +432,11 @@ class InputPreprocessor: ...@@ -457,13 +432,11 @@ class InputPreprocessor:
multi_modal_data, multi_modal_data,
parsed_content.get("mm_processor_kwargs"), parsed_content.get("mm_processor_kwargs"),
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids, mm_uuids=mm_uuids,
) )
else: else:
prompt_token_ids = self._tokenize_prompt( prompt_token_ids = self._tokenize_prompt(
prompt_text, prompt_text,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
) )
inputs = token_inputs( inputs = token_inputs(
...@@ -480,7 +453,6 @@ class InputPreprocessor: ...@@ -480,7 +453,6 @@ class InputPreprocessor:
self, self,
parsed_content: TextPrompt, parsed_content: TextPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*, *,
mm_uuids: Optional[MultiModalUUIDDict] = None, mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> Union[TokenInputs, MultiModalInputs]: ) -> Union[TokenInputs, MultiModalInputs]:
...@@ -493,13 +465,11 @@ class InputPreprocessor: ...@@ -493,13 +465,11 @@ class InputPreprocessor:
multi_modal_data, multi_modal_data,
parsed_content.get("mm_processor_kwargs"), parsed_content.get("mm_processor_kwargs"),
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids, mm_uuids=mm_uuids,
) )
else: else:
prompt_token_ids = await self._tokenize_prompt_async( prompt_token_ids = await self._tokenize_prompt_async(
prompt_text, prompt_text,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
) )
inputs = token_inputs( inputs = token_inputs(
...@@ -516,7 +486,6 @@ class InputPreprocessor: ...@@ -516,7 +486,6 @@ class InputPreprocessor:
self, self,
prompt: SingletonPrompt, prompt: SingletonPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*, *,
mm_uuids: Optional[MultiModalUUIDDict] = None, mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> SingletonInputs: ) -> SingletonInputs:
...@@ -526,7 +495,6 @@ class InputPreprocessor: ...@@ -526,7 +495,6 @@ class InputPreprocessor:
Arguments: Arguments:
* prompt: single encoder or decoder input prompt * prompt: single encoder or decoder input prompt
* lora_request: this is only valid for decoder prompts
Returns: Returns:
...@@ -539,21 +507,18 @@ class InputPreprocessor: ...@@ -539,21 +507,18 @@ class InputPreprocessor:
if parsed["type"] == "tokens": if parsed["type"] == "tokens":
return self._process_tokens( return self._process_tokens(
parsed["content"], parsed["content"],
lora_request=lora_request,
mm_uuids=mm_uuids, mm_uuids=mm_uuids,
) )
if parsed["type"] == "text": if parsed["type"] == "text":
return self._process_text( return self._process_text(
parsed["content"], parsed["content"],
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids, mm_uuids=mm_uuids,
) )
if parsed["type"] == "str": if parsed["type"] == "str":
return self._process_text( return self._process_text(
TextPrompt(prompt=parsed["content"]), TextPrompt(prompt=parsed["content"]),
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids, mm_uuids=mm_uuids,
) )
...@@ -563,7 +528,6 @@ class InputPreprocessor: ...@@ -563,7 +528,6 @@ class InputPreprocessor:
self, self,
prompt: SingletonPrompt, prompt: SingletonPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*, *,
mm_uuids: Optional[MultiModalUUIDDict] = None, mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> SingletonInputs: ) -> SingletonInputs:
...@@ -578,21 +542,18 @@ class InputPreprocessor: ...@@ -578,21 +542,18 @@ class InputPreprocessor:
if parsed["type"] == "tokens": if parsed["type"] == "tokens":
return await self._process_tokens_async( return await self._process_tokens_async(
parsed["content"], parsed["content"],
lora_request=lora_request,
mm_uuids=mm_uuids, mm_uuids=mm_uuids,
) )
if parsed["type"] == "text": if parsed["type"] == "text":
return await self._process_text_async( return await self._process_text_async(
parsed["content"], parsed["content"],
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids, mm_uuids=mm_uuids,
) )
if parsed["type"] == "str": if parsed["type"] == "str":
return await self._process_text_async( return await self._process_text_async(
TextPrompt(prompt=parsed["content"]), TextPrompt(prompt=parsed["content"]),
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids, mm_uuids=mm_uuids,
) )
...@@ -844,7 +805,6 @@ class InputPreprocessor: ...@@ -844,7 +805,6 @@ class InputPreprocessor:
self, self,
prompt: SingletonPrompt, prompt: SingletonPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*, *,
mm_uuids: Optional[MultiModalUUIDDict] = None, mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> DecoderOnlyInputs: ) -> DecoderOnlyInputs:
...@@ -856,7 +816,6 @@ class InputPreprocessor: ...@@ -856,7 +816,6 @@ class InputPreprocessor:
Arguments: Arguments:
* prompt: input prompt * prompt: input prompt
* lora_request
Returns: Returns:
...@@ -866,7 +825,6 @@ class InputPreprocessor: ...@@ -866,7 +825,6 @@ class InputPreprocessor:
prompt_comps = self._prompt_to_llm_inputs( prompt_comps = self._prompt_to_llm_inputs(
prompt, prompt,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids, mm_uuids=mm_uuids,
) )
...@@ -876,7 +834,6 @@ class InputPreprocessor: ...@@ -876,7 +834,6 @@ class InputPreprocessor:
self, self,
prompt: SingletonPrompt, prompt: SingletonPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*, *,
mm_uuids: Optional[MultiModalUUIDDict] = None, mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> DecoderOnlyInputs: ) -> DecoderOnlyInputs:
...@@ -887,7 +844,6 @@ class InputPreprocessor: ...@@ -887,7 +844,6 @@ class InputPreprocessor:
prompt_comps = await self._prompt_to_llm_inputs_async( prompt_comps = await self._prompt_to_llm_inputs_async(
prompt, prompt,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids, mm_uuids=mm_uuids,
) )
...@@ -897,7 +853,6 @@ class InputPreprocessor: ...@@ -897,7 +853,6 @@ class InputPreprocessor:
self, self,
prompt: PromptType, prompt: PromptType,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*, *,
mm_uuids: Optional[MultiModalUUIDDict] = None, mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> ProcessorInputs: ) -> ProcessorInputs:
...@@ -919,7 +874,6 @@ class InputPreprocessor: ...@@ -919,7 +874,6 @@ class InputPreprocessor:
return self._process_decoder_only_prompt( return self._process_decoder_only_prompt(
prompt, prompt,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids, mm_uuids=mm_uuids,
) )
...@@ -927,7 +881,6 @@ class InputPreprocessor: ...@@ -927,7 +881,6 @@ class InputPreprocessor:
self, self,
prompt: PromptType, prompt: PromptType,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*, *,
mm_uuids: Optional[MultiModalUUIDDict] = None, mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> ProcessorInputs: ) -> ProcessorInputs:
...@@ -952,7 +905,6 @@ class InputPreprocessor: ...@@ -952,7 +905,6 @@ class InputPreprocessor:
return await self._process_decoder_only_prompt_async( return await self._process_decoder_only_prompt_async(
prompt, prompt,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids, mm_uuids=mm_uuids,
) )
......
...@@ -10,18 +10,13 @@ from vllm.sequence import (VLLM_INVALID_TOKEN_ID, SamplingParams, Sequence, ...@@ -10,18 +10,13 @@ from vllm.sequence import (VLLM_INVALID_TOKEN_ID, SamplingParams, Sequence,
from .detokenizer_utils import (convert_prompt_ids_to_tokens, from .detokenizer_utils import (convert_prompt_ids_to_tokens,
detokenize_incrementally) detokenize_incrementally)
from .tokenizer import AnyTokenizer from .tokenizer import AnyTokenizer
from .tokenizer_group import TokenizerGroup
class Detokenizer: class Detokenizer:
"""Provides methods to decode the output of a model into text.""" """Provides methods to decode the output of a model into text."""
def __init__(self, tokenizer_group: TokenizerGroup): def __init__(self, tokenizer: AnyTokenizer):
self.tokenizer_group = tokenizer_group self.tokenizer = tokenizer
def get_tokenizer_for_seq(self, sequence: Sequence) -> AnyTokenizer:
"""Returns the HF tokenizer to use for a given sequence."""
return self.tokenizer_group.get_lora_tokenizer(sequence.lora_request)
def decode_prompt_logprobs_inplace(self, seq_group: SequenceGroup, def decode_prompt_logprobs_inplace(self, seq_group: SequenceGroup,
prompt_logprobs: list[Optional[dict[ prompt_logprobs: list[Optional[dict[
...@@ -32,9 +27,9 @@ class Detokenizer: ...@@ -32,9 +27,9 @@ class Detokenizer:
Args: Args:
seq_group: The sequence group to decode. seq_group: The sequence group to decode.
prompt_logprobs: The logprobs to decode. prompt_logprobs: The logprobs to decode.
position_offset: Offset of the first index of the logprobs position_offset: Offset of the first index of the logprobs
relative to the start of the sequence (for chunked prefill). relative to the start of the sequence (for chunked prefill).
Returns: Returns:
The prompt logprobs with the decoded tokens. The prompt logprobs with the decoded tokens.
""" """
...@@ -46,7 +41,6 @@ class Detokenizer: ...@@ -46,7 +41,6 @@ class Detokenizer:
# Only prompt, without the generated token. # Only prompt, without the generated token.
all_token_ids = seq.get_token_ids() all_token_ids = seq.get_token_ids()
prompt_token_ids = all_token_ids[:-1] prompt_token_ids = all_token_ids[:-1]
tokenizer = self.get_tokenizer_for_seq(seq)
prefix_offset = 0 prefix_offset = 0
read_offset = 0 read_offset = 0
next_iter_prefix_offset = 0 next_iter_prefix_offset = 0
...@@ -70,7 +64,7 @@ class Detokenizer: ...@@ -70,7 +64,7 @@ class Detokenizer:
prompt_token_ids[:token_position] + [token_id]) prompt_token_ids[:token_position] + [token_id])
(new_tokens, new_text, new_prefix_offset, (new_tokens, new_text, new_prefix_offset,
new_read_offset) = detokenize_incrementally( new_read_offset) = detokenize_incrementally(
tokenizer=tokenizer, tokenizer=self.tokenizer,
all_input_ids=prompt_token_ids_with_token, all_input_ids=prompt_token_ids_with_token,
prev_tokens=prev_tokens, prev_tokens=prev_tokens,
prefix_offset=prefix_offset, prefix_offset=prefix_offset,
...@@ -111,7 +105,6 @@ class Detokenizer: ...@@ -111,7 +105,6 @@ class Detokenizer:
""" """
all_input_ids = seq.get_token_ids() all_input_ids = seq.get_token_ids()
token_id_generated_this_iteration = all_input_ids[-1] token_id_generated_this_iteration = all_input_ids[-1]
tokenizer = self.get_tokenizer_for_seq(seq)
# Convert prompt token IDs to tokens if necessary. # Convert prompt token IDs to tokens if necessary.
# Do it here so that we don't have to repeat this # Do it here so that we don't have to repeat this
...@@ -119,14 +112,14 @@ class Detokenizer: ...@@ -119,14 +112,14 @@ class Detokenizer:
if seq.tokens is None: if seq.tokens is None:
(seq.tokens, seq.prefix_offset, (seq.tokens, seq.prefix_offset,
seq.read_offset) = convert_prompt_ids_to_tokens( seq.read_offset) = convert_prompt_ids_to_tokens(
tokenizer=tokenizer, tokenizer=self.tokenizer,
prompt_ids=all_input_ids[:-1], prompt_ids=all_input_ids[:-1],
skip_special_tokens=prms.skip_special_tokens, skip_special_tokens=prms.skip_special_tokens,
) )
(new_tokens, new_decoded_token_text, prefix_offset, (new_tokens, new_decoded_token_text, prefix_offset,
read_offset) = detokenize_incrementally( read_offset) = detokenize_incrementally(
tokenizer=tokenizer, tokenizer=self.tokenizer,
all_input_ids=all_input_ids, all_input_ids=all_input_ids,
prev_tokens=seq.tokens, prev_tokens=seq.tokens,
prefix_offset=seq.prefix_offset, prefix_offset=seq.prefix_offset,
...@@ -150,7 +143,7 @@ class Detokenizer: ...@@ -150,7 +143,7 @@ class Detokenizer:
and token_id != VLLM_INVALID_TOKEN_ID): and token_id != VLLM_INVALID_TOKEN_ID):
all_input_ids_with_logprob = previous_tokens + [token_id] all_input_ids_with_logprob = previous_tokens + [token_id]
(_, new_text, _, _) = detokenize_incrementally( (_, new_text, _, _) = detokenize_incrementally(
tokenizer=tokenizer, tokenizer=self.tokenizer,
all_input_ids=all_input_ids_with_logprob, all_input_ids=all_input_ids_with_logprob,
prev_tokens=seq.tokens, prev_tokens=seq.tokens,
prefix_offset=seq.prefix_offset, prefix_offset=seq.prefix_offset,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment