Unverified Commit 6c47f6bf authored by Zhuohan Li's avatar Zhuohan Li Committed by GitHub
Browse files

[Core] Remove tokenizer group in vLLM (#24078)


Signed-off-by: default avatarZhuohan Li <zhuohan123@gmail.com>
parent c15309a7
......@@ -43,7 +43,7 @@ def _ref_convert_id_to_token(
[RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
def test_incremental_detokenization(request_output_kind: RequestOutputKind,
dummy_test_vectors):
output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group,
output_processor = OutputProcessor(dummy_test_vectors.tokenizer,
log_stats=False)
engine_core = MockEngineCore(
tokens_list=dummy_test_vectors.generation_tokens)
......@@ -382,7 +382,7 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind,
num_sample_logprobs: Optional[int],
num_prompt_logprobs: Optional[int],
dummy_test_vectors):
output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group,
output_processor = OutputProcessor(dummy_test_vectors.tokenizer,
log_stats=False)
engine_core = MockEngineCore(
tokens_list=dummy_test_vectors.generation_tokens,
......@@ -535,7 +535,7 @@ def test_stop_token(include_stop_str_in_output: bool,
) # '<|end_of_text|>'
stop_token_ids = [128009] if not is_eos_test else None # '<|eot_id|>'
output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group,
output_processor = OutputProcessor(dummy_test_vectors.tokenizer,
log_stats=False)
# Dummy engine core outputs, with control tokens suffixed to test stops
suffix_token = ([eos_token_id] if is_eos_test else stop_token_ids)
......@@ -642,7 +642,7 @@ def test_stop_token(include_stop_str_in_output: bool,
[None, NUM_SAMPLE_LOGPROBS_UNDER_TEST])
def test_stop_string(include_stop_str_in_output: bool,
num_sample_logprobs: Optional[int], dummy_test_vectors):
output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group,
output_processor = OutputProcessor(dummy_test_vectors.tokenizer,
log_stats=False)
engine_core = MockEngineCore(
tokens_list=dummy_test_vectors.generation_tokens,
......@@ -763,7 +763,7 @@ def test_stop_string(include_stop_str_in_output: bool,
def test_iteration_stats(dummy_test_vectors):
output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group,
output_processor = OutputProcessor(dummy_test_vectors.tokenizer,
log_stats=True)
engine_core = MockEngineCore(dummy_test_vectors.generation_tokens)
engine_core_timestamp = time.monotonic()
......
......@@ -9,7 +9,6 @@ import torch
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from vllm.engine.arg_utils import EngineArgs
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.v1.engine import EngineCoreOutput, FinishReason
from vllm.v1.outputs import LogprobsLists, LogprobsTensors
......@@ -39,7 +38,7 @@ def _create_random_top_logprob_test_vector(
upper: float,
) -> torch.Tensor:
"""Create a random vector of top logprob float values.
Use to create fake sample logprobs for testing.
Note that a real production scenario would require
......@@ -63,7 +62,7 @@ def _create_random_top_logprob_test_matrix(
upper: float,
) -> torch.Tensor:
"""Create a random matrix of top logprob float values.
Use to create fake prompt logprobs for testing.
Note that a real production scenario would require
......@@ -296,7 +295,6 @@ def generate_dummy_prompt_logprobs_tensors(
class DummyOutputProcessorTestVectors:
"""Dummy test vectors for output processor tests"""
tokenizer: GeneralTokenizerType
tokenizer_group: TokenizerGroup
vllm_config: EngineArgs
full_tokens: list[list[int]] # Prompt + generated tokens
prompt_tokens: list[list[int]]
......
......@@ -582,7 +582,7 @@ def test_structured_output_with_reasoning_matrices(
reasoning_parser=reasoning_parser,
speculative_config=speculative_config,
)
tokenizer = llm.get_tokenizer(None)
tokenizer = llm.get_tokenizer()
reasoner = ReasoningParserManager.get_reasoning_parser(reasoning_parser)(
tokenizer=tokenizer)
......
......@@ -37,7 +37,7 @@ from vllm.lora.request import LoRARequest
from vllm.lora.utils import get_adapter_absolute_path
from vllm.multimodal import MultiModalDataDict
from vllm.multimodal.image import convert_image_mode
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import PlaceholderModule
try:
......@@ -100,8 +100,8 @@ class BenchmarkDataset(ABC):
) -> None:
"""
Initialize the BenchmarkDataset with an optional dataset path and random
seed.
seed.
Args:
dataset_path (Optional[str]): Path to the dataset. If None, it
indicates that a default or random dataset might be used.
......@@ -133,10 +133,10 @@ class BenchmarkDataset(ABC):
elif isinstance(mm_content, dict):
content.append(mm_content)
else:
raise TypeError(
raise TypeError(
"Could not process multimodal content of type: " +
f"{type(mm_content)}"
)
f"{type(mm_content)}"
)
return [{"role": "user", "content": content}]
def load_data(self) -> None:
......@@ -155,34 +155,26 @@ class BenchmarkDataset(ABC):
def get_random_lora_request(
self,
tokenizer: PreTrainedTokenizerBase,
max_loras: Optional[int] = None,
lora_path: Optional[str] = None,
) -> tuple[Optional[LoRARequest], AnyTokenizer]:
) -> Optional[LoRARequest]:
"""
Optionally select a random LoRA request and return its associated
tokenizer.
Optionally select a random LoRA request.
This method is used when LoRA parameters are provided. It randomly
selects a LoRA based on max_loras and retrieves a cached tokenizer for
that LoRA if available. Otherwise, it returns the base tokenizer.
selects a LoRA based on max_loras.
Args:
tokenizer (PreTrainedTokenizerBase): The base tokenizer to use if no
LoRA is selected.
max_loras (Optional[int]): The maximum number of LoRAs available.
If `None`, LoRA is not used.
lora_path (Optional[str]): Path to the LoRA parameters on disk.
If `None`, LoRA is not used.
Returns:
A tuple with the following elements:
- A new [LoRARequest][] (or `None` if not applicable).
- The tokenizer associated with the LoRA request
(or the base tokenizer).
A new [LoRARequest][] (or `None` if not applicable).
"""
if max_loras is None or lora_path is None:
return None, tokenizer
return None
# Generate a random LoRA ID in the range [1, max_loras].
lora_id = random.randint(1, max_loras)
......@@ -191,11 +183,7 @@ class BenchmarkDataset(ABC):
lora_int_id=lora_id,
lora_path=lora_path_on_disk(lora_path),
)
if lora_id not in lora_tokenizer_cache:
lora_tokenizer_cache[lora_id] = get_lora_tokenizer(lora_request)
# Return lora_request and the cached tokenizer if available; otherwise,
# return the base tokenizer
return lora_request, lora_tokenizer_cache[lora_id] or tokenizer
return lora_request
@abstractmethod
def sample(self, tokenizer: PreTrainedTokenizerBase,
......@@ -213,7 +201,7 @@ class BenchmarkDataset(ABC):
for processing the dataset's text.
num_requests (int): The number of sample requests to generate.
request_id_prefix (str) The prefix of request_id.
Returns:
list[SampleRequest]: A list of sample requests generated from the
......@@ -527,7 +515,7 @@ class RandomDataset(BenchmarkDataset):
size=num_requests)
output_lens = self._rng.integers(output_low, output_high + 1,
size=num_requests)
offsets = self._rng.integers(0, tokenizer.vocab_size,
offsets = self._rng.integers(0, tokenizer.vocab_size,
size=num_requests)
return input_lens, output_lens, offsets
......@@ -555,7 +543,7 @@ class RandomDataset(BenchmarkDataset):
the encoded sequence is truncated before being decoded again.
"""
# Build the inner sequence by sampling sequentially from the vocab
inner_seq = ((offset + index + np.arange(input_len))
inner_seq = ((offset + index + np.arange(input_len))
% vocab_size).tolist()
token_sequence = prefix_token_ids + inner_seq
......@@ -590,9 +578,9 @@ class RandomMultiModalDataset(RandomDataset):
`num_mm_items_range_ratio` in [0, 1]. r=0 keeps it fixed; r=1 allows 0.
The maximum is further clamped to the sum of per-modality limits.
2) Each item’s modality and shape is sampled from `bucket_config`, a dict
mapping (height, width, num_frames) → probability. We treat
`num_frames`=1 as image and and `num_frames` > 1 as video.
Entries with zero probability are removed and the rest are renormalized
mapping (height, width, num_frames) → probability. We treat
`num_frames`=1 as image and and `num_frames` > 1 as video.
Entries with zero probability are removed and the rest are renormalized
to sum to 1.
3) Per-modality hard caps are enforced via `limit_mm_per_prompt`.
When a modality reaches its cap, all of its buckets are excluded and the
......@@ -600,8 +588,8 @@ class RandomMultiModalDataset(RandomDataset):
Example bucket configuration:
{(256, 256, 1): 0.5, (720, 1280, 1): 0.4, (720, 1280, 16): 0.1}
- Two image buckets (`num_frames`=1) and one video bucket
(`num_frames`=16).
- Two image buckets (`num_frames`=1) and one video bucket
(`num_frames`=16).
OBS.: Only image sampling is supported for now.
"""
......@@ -624,9 +612,9 @@ class RandomMultiModalDataset(RandomDataset):
def generate_synthetic_image(self, width: int, height: int) -> Image.Image:
"""Generate synthetic PIL image with random RGB values.
NOTE: iid pixel sampling results in worst-case compression
(good for stressing I/O), but very unlike real photos.
NOTE: iid pixel sampling results in worst-case compression
(good for stressing I/O), but very unlike real photos.
We could consider a “low-freq” mode (e.g., noise blur)
to emulate network realism instead of max stress.
"""
......@@ -638,11 +626,11 @@ class RandomMultiModalDataset(RandomDataset):
)
return Image.fromarray(random_pixels)
def generate_synthetic_video(self, width: int,
height: int,
def generate_synthetic_video(self, width: int,
height: int,
num_frames: int) -> Any:
"""Generate synthetic video with random values.
TODO: Finish this method.
"""
raise NotImplementedError("Video sampling is WIP.")
......@@ -656,7 +644,7 @@ class RandomMultiModalDataset(RandomDataset):
else:
raise ValueError(f"Invalid multimodal item configuration: {config}")
def normalize_bucket_config(self, bucket_config: dict[tuple[int, int, int],
def normalize_bucket_config(self, bucket_config: dict[tuple[int, int, int],
float]) -> dict[tuple[int, int, int], float]:
"""
Remove zero probability entries
......@@ -676,24 +664,24 @@ class RandomMultiModalDataset(RandomDataset):
return {k: v / total for k, v in bucket_config.items()}
def generate_mm_item(self,
def generate_mm_item(self,
mm_item_config: tuple[int, int, int],
) -> Mapping[str, Any]:
"""
Create synthetic images and videos and
Create synthetic images and videos and
apply process_image/process_video respectively.
This follows the OpenAI API chat completions
https://github.com/openai/openai-python
"""
if self.map_config_to_modality(mm_item_config) == "image":
return process_image(self.generate_synthetic_image(
mm_item_config[1],
mm_item_config[0]))
elif self.map_config_to_modality(mm_item_config) == "video":
return process_video(self.generate_synthetic_video(
mm_item_config[1],
mm_item_config[0],
mm_item_config[1],
mm_item_config[0],
mm_item_config[2]))
else:
raise ValueError(f"Invalid multimodal item configuration: "
......@@ -723,17 +711,17 @@ class RandomMultiModalDataset(RandomDataset):
f"limit_mm_per_prompt: "
f"{limit_mm_per_prompt.keys()}")
# Remove zero probability entries
# Remove zero probability entries
# and normalize bucket config to sum to 1
bucket_config = self.normalize_bucket_config(bucket_config)
logger.info(
"Normalized bucket config: %s", bucket_config,
)
# Only consider limit per prompt for modalities in bucket config
allowed_modalities = {self.map_config_to_modality(cfg)
allowed_modalities = {self.map_config_to_modality(cfg)
for cfg in bucket_config}
limit_mm_per_prompt = {
k: v for k, v in limit_mm_per_prompt.items()
k: v for k, v in limit_mm_per_prompt.items()
if k in allowed_modalities}
if not limit_mm_per_prompt:
raise ValueError("No valid limits for modalities present in "
......@@ -746,19 +734,19 @@ class RandomMultiModalDataset(RandomDataset):
# Get max and min num mm items and ensure
# it is at most the sum of limit_mm_per_prompt for all modalities
max_num_mm_items = min(
sum(limit_mm_per_prompt.values()),
sum(limit_mm_per_prompt.values()),
math.ceil(base_items_per_request * (1 + num_mm_items_range_ratio))
)
# Ensure min num mm items is at least 0
min_num_mm_items = max(
0,
0,
math.floor(base_items_per_request * (1 - num_mm_items_range_ratio))
)
# Raise error if min num mm items is greater than max num mm items
if min_num_mm_items > max_num_mm_items:
raise ValueError(f"Min num mm items is greater than max mm items: "
f"{min_num_mm_items} > {max_num_mm_items}")
logger.info(
"Sampling number of multimodal items from [%s, %s]",
min_num_mm_items, max_num_mm_items,
......@@ -783,8 +771,8 @@ class RandomMultiModalDataset(RandomDataset):
whose size is between min_num_mm_items and max_num_mm_items.
Loop over the bucket config and sample a multimodal item.
Loop until the number of multimodal items sampled is equal to
request_num_mm_items or limit of multimodal items per prompt
Loop until the number of multimodal items sampled is equal to
request_num_mm_items or limit of multimodal items per prompt
for all modalities is reached.
Note:
......@@ -796,19 +784,19 @@ class RandomMultiModalDataset(RandomDataset):
# Get the number of multimodal items to sample
request_num_mm_items = int(
self._rng.integers(min_num_mm_items, max_num_mm_items + 1)
)
)
# If request_num_mm_items is 0, yield an empty iterator
if request_num_mm_items == 0:
return
# Initialize modality counters
modality_counter = {self.map_config_to_modality(k): 0
modality_counter = {self.map_config_to_modality(k): 0
for k in bucket_config}
# Copy the bucket config to avoid modifying the original
bucket_config_copy = bucket_config.copy()
# Loop over the number of multimodal items to sample
while sum(modality_counter.values()) < request_num_mm_items:
# Sample a multimodal item config
mm_item_config = self._rng.choice(list(bucket_config_copy.keys()),
mm_item_config = self._rng.choice(list(bucket_config_copy.keys()),
p=list(bucket_config_copy.values()))
modality = self.map_config_to_modality(mm_item_config)
# Check that modality count is less than limit per prompt
......@@ -849,7 +837,7 @@ class RandomMultiModalDataset(RandomDataset):
limit_mm_per_prompt: dict[str, int] = DEFAULT_LIMIT_MM_PER_PROMPT,
base_items_per_request: int = DEFAULT_BASE_ITEMS_PER_REQUEST,
num_mm_items_range_ratio: float = DEFAULT_NUM_MM_ITEMS_RANGE_RATIO,
bucket_config: dict[tuple[int, int, int], float] =
bucket_config: dict[tuple[int, int, int], float] =
DEFAULT_MM_ITEM_BUCKET_CONFIG,
enable_multimodal_chat: bool = DEFAULT_ENABLE_MULTIMODAL_CHAT,
**kwargs,
......@@ -857,7 +845,7 @@ class RandomMultiModalDataset(RandomDataset):
# NOTE: Video sampling is WIP. Raise error if video is in bucket config
# and probability is non-zero.
if any(self.map_config_to_modality(cfg) == "video" and p > 0
if any(self.map_config_to_modality(cfg) == "video" and p > 0
for cfg, p in bucket_config.items()):
raise NotImplementedError("Video sampling not implemented; "
"set its probability to 0.")
......@@ -908,7 +896,7 @@ class RandomMultiModalDataset(RandomDataset):
])
if enable_multimodal_chat:
# NOTE: For now this option is only provided for completeness
# NOTE: For now this option is only provided for completeness
# given that the serve.py benchmark currently does not use it.
mm_chat_prompt: Any = prompt
mm_chat_prompt = self.apply_multimodal_chat_transformation(
......@@ -982,8 +970,8 @@ class ShareGPTDataset(BenchmarkDataset):
entry["conversations"][1]["value"],
)
lora_request, tokenizer = self.get_random_lora_request(
tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path)
lora_request = self.get_random_lora_request(
max_loras=max_loras, lora_path=lora_path)
prompt_ids = tokenizer(prompt).input_ids
completion_ids = tokenizer(completion).input_ids
prompt_len = len(prompt_ids)
......@@ -994,11 +982,11 @@ class ShareGPTDataset(BenchmarkDataset):
skip_min_output_len_check=output_len
is not None):
continue
if image_path := entry.get("image"):
mm_content = process_image(image_path)
elif video_path := entry.get("video"):
if image_path := entry.get("image"):
mm_content = process_image(image_path)
elif video_path := entry.get("video"):
mm_content = process_video(video_path)
else:
else:
mm_content = None
if enable_multimodal_chat:
prompt = self.apply_multimodal_chat_transformation(
......@@ -1013,9 +1001,9 @@ class ShareGPTDataset(BenchmarkDataset):
request_id=request_id_prefix + str(ind),
))
ind += 1
self.maybe_oversample_requests(samples,
num_requests,
request_id_prefix,
self.maybe_oversample_requests(samples,
num_requests,
request_id_prefix,
no_oversample)
return samples
......@@ -1024,11 +1012,11 @@ class _ValidateDatasetArgs(argparse.Action):
"""Argparse action to validate dataset name and path compatibility."""
def __call__(self, parser, namespace, values, option_string=None):
setattr(namespace, self.dest, values)
# Get current values of both dataset_name and dataset_path
dataset_name = getattr(namespace, 'dataset_name', 'random')
dataset_path = getattr(namespace, 'dataset_path', None)
# Validate the combination
if dataset_name == "random" and dataset_path is not None:
parser.error(
......@@ -1053,7 +1041,7 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
default="random",
action=_ValidateDatasetArgs,
choices=[
"sharegpt", "burstgpt", "sonnet", "random", "random-mm", "hf",
"sharegpt", "burstgpt", "sonnet", "random", "random-mm", "hf",
"custom", "prefix_repetition", "spec_bench"
],
help="Name of the dataset to benchmark on.",
......@@ -1502,7 +1490,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
# For datasets that follow a similar structure, use a mapping.
dataset_mapping = {
"spec_bench":
lambda: SpecBench(dataset_path=args.dataset_path,
lambda: SpecBench(dataset_path=args.dataset_path,
category=args.spec_bench_category).sample(
num_requests=args.num_prompts,
tokenizer=tokenizer,
......@@ -1660,7 +1648,7 @@ class CustomDataset(BenchmarkDataset):
logger.info("num_requests is set to 0 or negative, "
"so using all available samples: %d",
num_requests)
sampled_requests = []
for i, item in enumerate(self.data):
if len(sampled_requests) >= num_requests:
......@@ -1686,7 +1674,7 @@ class CustomDataset(BenchmarkDataset):
expected_output_len=output_len,
request_id=request_id_prefix + str(i),
))
self.maybe_oversample_requests(sampled_requests, num_requests,
self.maybe_oversample_requests(sampled_requests, num_requests,
request_id_prefix, no_oversample)
return sampled_requests
......@@ -1700,7 +1688,7 @@ class CustomDataset(BenchmarkDataset):
class SpecBench(CustomDataset):
"""
Implements the SpecBench dataset: https://github.com/hemingkx/Spec-Bench
Download the dataset using:
Download the dataset using:
wget https://raw.githubusercontent.com/hemingkx/Spec-Bench/refs/heads/main/data/spec_bench/question.jsonl
""" # noqa: E501
......@@ -1736,8 +1724,8 @@ class SpecBench(CustomDataset):
# leverage CustomDataset sample
kwargs["skip_chat_template"] = False
return super().sample(**kwargs)
# -----------------------------------------------------------------------------
# Sonnet Dataset Implementation
# -----------------------------------------------------------------------------
......@@ -1882,8 +1870,8 @@ class BurstGPTDataset(BenchmarkDataset):
for i in range(num_requests):
input_len = int(data[i][2])
output_len = int(data[i][3])
lora_req, tokenizer = self.get_random_lora_request(
tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path)
lora_req = self.get_random_lora_request(
max_loras=max_loras, lora_path=lora_path)
vocab_size = tokenizer.vocab_size
# Generate a synthetic prompt: a list of token IDs computed as (i +
# j) modulo vocab_size.
......@@ -1995,7 +1983,7 @@ class ConversationDataset(HuggingFaceDataset):
request_id=request_id_prefix + str(ind),
))
ind += 1
self.maybe_oversample_requests(sampled_requests, num_requests,
self.maybe_oversample_requests(sampled_requests, num_requests,
request_id_prefix, no_oversample)
return sampled_requests
......@@ -2055,7 +2043,7 @@ class VisionArenaDataset(HuggingFaceDataset):
multi_modal_data=mm_content,
request_id=request_id_prefix + str(i),
))
self.maybe_oversample_requests(sampled_requests, num_requests,
self.maybe_oversample_requests(sampled_requests, num_requests,
request_id_prefix, no_oversample)
return sampled_requests
......@@ -2172,7 +2160,7 @@ class InstructCoderDataset(HuggingFaceDataset):
expected_output_len=output_len,
request_id=request_id_prefix + str(i),
))
self.maybe_oversample_requests(sampled_requests, num_requests,
self.maybe_oversample_requests(sampled_requests, num_requests,
request_id_prefix, no_oversample)
return sampled_requests
......@@ -2234,7 +2222,7 @@ class MTBenchDataset(HuggingFaceDataset):
expected_output_len=output_len,
request_id=request_id_prefix + str(i),
))
self.maybe_oversample_requests(sampled_requests, num_requests,
self.maybe_oversample_requests(sampled_requests, num_requests,
request_id_prefix, no_oversample)
return sampled_requests
......@@ -2288,8 +2276,8 @@ class BlazeditDataset(HuggingFaceDataset):
# compare the levenshtein distance normalized by code length
if norm_distance < min_distance or norm_distance > max_distance:
continue
# template copied from
# template copied from
# https://github.com/ise-uiuc/blazedit/blob/7765137e656fd62de877422d2e4cf8de51228054/dataset/create_refined_dataset.py#L94-L105 # noqa: E501
instruction = f"""Given a code file, please apply the change requests and generate the new file.
......@@ -2322,9 +2310,9 @@ Please generate the new code file in the "New file" section below.""" # noqa: E5
expected_output_len=output_len,
request_id=request_id_prefix + str(i),
))
self.maybe_oversample_requests(sampled_requests, num_requests,
self.maybe_oversample_requests(sampled_requests, num_requests,
request_id_prefix, no_oversample)
return sampled_requests
......@@ -2376,7 +2364,6 @@ class AIMODataset(HuggingFaceDataset):
expected_output_len=output_len,
multi_modal_data=None,
request_id=request_id_prefix + str(ind),
))
ind += 1
self.maybe_oversample_requests(sampled_requests, num_requests,
......@@ -2470,9 +2457,9 @@ class NextEditPredictionDataset(HuggingFaceDataset):
))
if len(samples) >= num_requests:
break
self.maybe_oversample_requests(samples,
num_requests,
request_id_prefix,
self.maybe_oversample_requests(samples,
num_requests,
request_id_prefix,
no_oversample)
return samples
......@@ -2562,7 +2549,7 @@ class ASRDataset(HuggingFaceDataset):
" what Whisper supports.",
skipped,
)
self.maybe_oversample_requests(sampled_requests, num_requests,
self.maybe_oversample_requests(sampled_requests, num_requests,
request_id_prefix, no_oversample)
return sampled_requests
......@@ -2647,7 +2634,7 @@ class MLPerfDataset(HuggingFaceDataset):
)
ind += 1
self.maybe_oversample_requests(sampled_requests, num_requests,
self.maybe_oversample_requests(sampled_requests, num_requests,
request_id_prefix, no_oversample)
return sampled_requests
......@@ -2658,7 +2645,7 @@ class MLPerfDataset(HuggingFaceDataset):
class PrefixRepetitionRandomDataset(BenchmarkDataset):
# Default values copied from benchmark_serving.py for the repeated prefix
# Default values copied from benchmark_serving.py for the repeated prefix
# dataset.
DEFAULT_PREFIX_LEN = 256
DEFAULT_SUFFIX_LEN = 256
......
......@@ -390,11 +390,8 @@ class _AsyncLLMEngine(LLMEngine):
"""Stop the remote worker execution loop."""
await self.model_executor.stop_remote_worker_execution_loop_async()
async def get_tokenizer_async(self,
lora_request: Optional[LoRARequest] = None
) -> AnyTokenizer:
return await (
self.get_tokenizer_group().get_lora_tokenizer_async(lora_request))
async def get_tokenizer_async(self) -> AnyTokenizer:
return self.get_tokenizer()
async def add_request_async(
self,
......@@ -435,7 +432,6 @@ class _AsyncLLMEngine(LLMEngine):
processed_inputs = await self.input_preprocessor.preprocess_async(
prompt,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
)
......@@ -614,11 +610,8 @@ class AsyncLLMEngine(EngineClient):
async def get_input_preprocessor(self) -> InputPreprocessor:
return self.engine.input_preprocessor
async def get_tokenizer(
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
return await self.engine.get_tokenizer_async(lora_request)
async def get_tokenizer(self) -> AnyTokenizer:
return self.engine.get_tokenizer()
def start_background_loop(self) -> None:
"""Start the background loop."""
......
......@@ -49,9 +49,8 @@ from vllm.sequence import (ExecuteModelRequest, ParallelSampleSequenceGroup,
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
init_tracer)
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import (
TokenizerGroup, init_tokenizer_from_configs)
from vllm.transformers_utils.tokenizer import (AnyTokenizer,
init_tokenizer_from_configs)
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message)
from vllm.utils import Counter, Device, resolve_obj_by_qualname, weak_bind
......@@ -186,7 +185,7 @@ class LLMEngine:
return outputs_
tokenizer: Optional[TokenizerGroup]
tokenizer: Optional[AnyTokenizer]
def __init__(
self,
......@@ -233,18 +232,9 @@ class LLMEngine:
if self.model_config.skip_tokenizer_init:
self.tokenizer = None
self.detokenizer = None
tokenizer_group = None
else:
self.tokenizer = self._init_tokenizer()
self.detokenizer = Detokenizer(self.tokenizer)
tokenizer_group = self.get_tokenizer_group()
# Ensure that the function doesn't contain a reference to self,
# to avoid engine GC issues
def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
assert tokenizer_group, ("tokenizer_group cannot be None, "
"make sure skip_tokenizer_init is False")
return tokenizer_group.get_lora_tokenizer(sequence.lora_request)
self.seq_counter = Counter()
self.generation_config_fields = (
......@@ -389,10 +379,8 @@ class LLMEngine:
self.detokenizer,
self.scheduler,
self.seq_counter,
get_tokenizer_for_seq,
stop_checker=StopChecker(
self.scheduler_config.max_model_len,
get_tokenizer_for_seq,
self.reasoner if self.decoding_config.reasoning_backend
and self.tokenizer else None,
),
......@@ -521,24 +509,15 @@ class LLMEngine:
if model_executor := getattr(self, "model_executor", None):
model_executor.shutdown()
def get_tokenizer_group(self) -> TokenizerGroup:
def get_tokenizer(self) -> AnyTokenizer:
if self.tokenizer is None:
raise ValueError("Unable to get tokenizer because "
"skip_tokenizer_init is True")
return self.tokenizer
def get_tokenizer(
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
return self.get_tokenizer_group().get_lora_tokenizer(lora_request)
def _init_tokenizer(self) -> TokenizerGroup:
return init_tokenizer_from_configs(
model_config=self.model_config,
scheduler_config=self.scheduler_config,
lora_config=self.lora_config)
def _init_tokenizer(self) -> AnyTokenizer:
return init_tokenizer_from_configs(model_config=self.model_config)
def _verify_args(self) -> None:
self.model_config.verify_with_parallel_config(self.parallel_config)
......@@ -574,11 +553,11 @@ class LLMEngine:
)
return None
self._validate_model_inputs(processed_inputs, lora_request)
self._validate_model_inputs(processed_inputs)
# Create the sequences.
block_size = self.cache_config.block_size
seq_id = next(self.seq_counter)
eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)
eos_token_id = self.input_preprocessor.get_eos_token_id()
encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
......@@ -700,7 +679,6 @@ class LLMEngine:
processed_inputs = self.input_preprocessor.preprocess(
prompt,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
)
self._add_processed_request(
......@@ -1739,29 +1717,22 @@ class LLMEngine:
SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_EXECUTE,
metrics.model_execute_time)
def _validate_model_inputs(self, inputs: ProcessorInputs,
lora_request: Optional[LoRARequest]):
def _validate_model_inputs(self, inputs: ProcessorInputs):
encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs)
if encoder_inputs is not None:
self._validate_model_input(encoder_inputs,
lora_request,
prompt_type="encoder")
self._validate_model_input(encoder_inputs, prompt_type="encoder")
self._validate_model_input(decoder_inputs,
lora_request,
prompt_type="decoder")
self._validate_model_input(decoder_inputs, prompt_type="decoder")
def _validate_model_input(
self,
prompt_inputs: SingletonInputs,
lora_request: Optional[LoRARequest],
*,
prompt_type: Literal["encoder", "decoder"],
):
model_config = self.model_config
tokenizer = (None if self.tokenizer is None else
self.tokenizer.get_lora_tokenizer(lora_request))
tokenizer = self.tokenizer
prompt_ids = prompt_inputs.get("prompt_token_ids", [])
if not prompt_ids:
......@@ -1822,7 +1793,7 @@ class LLMEngine:
logits_processors = []
if (sampling_params.logit_bias or sampling_params.allowed_token_ids):
tokenizer = self.get_tokenizer(lora_request=lora_request)
tokenizer = self.get_tokenizer()
processors = get_openai_logits_processors(
logit_bias=sampling_params.logit_bias,
......@@ -1835,7 +1806,7 @@ class LLMEngine:
sampling_params.allowed_token_ids = None
if len(sampling_params.bad_words) > 0:
tokenizer = self.get_tokenizer(lora_request)
tokenizer = self.get_tokenizer()
processors = get_bad_words_logits_processors(
bad_words=sampling_params.bad_words, tokenizer=tokenizer)
logits_processors.extend(processors)
......
......@@ -2,14 +2,13 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from typing import Callable, List
from typing import List
from vllm.config import SchedulerConfig
from vllm.core.scheduler import Scheduler
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.sequence import Sequence, SequenceGroup, SequenceGroupOutput
from vllm.sequence import SequenceGroup, SequenceGroupOutput
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import Counter
......@@ -31,7 +30,6 @@ class SequenceGroupOutputProcessor(ABC):
detokenizer: Detokenizer,
scheduler: List[Scheduler],
seq_counter: Counter,
get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer],
stop_checker: "StopChecker",
):
"""Create an output processor.
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable, List, Optional, Tuple
from typing import List, Optional, Tuple
from vllm.lora.request import LoRARequest
from vllm.reasoning import ReasoningParser
from vllm.sampling_params import SamplingParams
from vllm.sequence import Sequence, SequenceStatus
from vllm.transformers_utils.tokenizer import AnyTokenizer
class StopChecker:
......@@ -20,12 +19,10 @@ class StopChecker:
def __init__(
self,
max_model_len: int,
get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer],
reasoner: Optional[ReasoningParser] = None,
):
# Do not use it directly, but use `self._get_max_model_len`.
self._max_model_len = max_model_len
self.get_tokenizer_for_seq = get_tokenizer_for_seq
self.reasoner = reasoner
def _get_max_model_len(self, lora_req: Optional[LoRARequest]):
......
......@@ -76,8 +76,7 @@ class EngineClient(ABC):
include_stop_str_in_output = params.include_stop_str_in_output
preprocessor = await self.get_input_preprocessor()
tokenizer_group = preprocessor.get_tokenizer_group()
tokenizer = await tokenizer_group.get_lora_tokenizer_async()
tokenizer = preprocessor.get_tokenizer()
eos_token_id = tokenizer.eos_token_id
if is_explicit_encoder_decoder_prompt(prompt):
......@@ -260,11 +259,8 @@ class EngineClient(ABC):
...
@abstractmethod
async def get_tokenizer(
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
"""Get the appropriate tokenizer for the request"""
async def get_tokenizer(self) -> AnyTokenizer:
"""Get the tokenizer"""
...
async def get_io_processor(self) -> IOProcessor:
......
......@@ -301,23 +301,17 @@ class LLM:
self.io_processor = get_io_processor(self.llm_engine.vllm_config,
io_processor_plugin)
def get_tokenizer(
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
return self.llm_engine.get_tokenizer_group().get_lora_tokenizer(
lora_request)
def get_tokenizer(self) -> AnyTokenizer:
return self.llm_engine.get_tokenizer()
def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
tokenizer_group = self.llm_engine.get_tokenizer_group()
# While CachedTokenizer is dynamic, have no choice but
# compare class name. Misjudgment will arise from
# user-defined tokenizer started with 'Cached'
if tokenizer.__class__.__name__.startswith("Cached"):
tokenizer_group.tokenizer = tokenizer
self.llm_engine.tokenizer = tokenizer
else:
tokenizer_group.tokenizer = get_cached_tokenizer(tokenizer)
self.llm_engine.tokenizer = get_cached_tokenizer(tokenizer)
def get_default_sampling_params(self) -> SamplingParams:
if self.default_sampling_params is None:
......@@ -707,7 +701,6 @@ class LLM:
self,
messages: Union[list[ChatCompletionMessageParam],
list[list[ChatCompletionMessageParam]]],
lora_request: Optional[LoRARequest] = None,
chat_template: Optional[str] = None,
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
add_generation_prompt: bool = True,
......@@ -739,7 +732,7 @@ class LLM:
cast(list[ChatCompletionMessageParam], messages)
]
tokenizer = self.get_tokenizer(lora_request)
tokenizer = self.get_tokenizer()
model_config = self.llm_engine.get_model_config()
resolved_content_format = resolve_chat_template_content_format(
chat_template,
......@@ -872,7 +865,6 @@ class LLM:
prompts = self.preprocess_chat(
messages=messages,
lora_request=lora_request,
chat_template=chat_template,
chat_template_content_format=chat_template_content_format,
add_generation_prompt=add_generation_prompt,
......@@ -1519,7 +1511,7 @@ class LLM:
):
"""
Validate that if any multi-modal data is skipped (i.e. None),
then its corresponding UUID must be set.
then its corresponding UUID must be set.
"""
if multi_modal_data is None:
return
......
......@@ -188,7 +188,7 @@ class OpenAIServingChat(OpenAIServing):
model_name = self.models.model_name(lora_request)
tokenizer = await self.engine_client.get_tokenizer(lora_request)
tokenizer = await self.engine_client.get_tokenizer()
tool_parser = self.tool_parser
......
......@@ -50,10 +50,7 @@ class ClassificationMixin(OpenAIServing):
return None
try:
ctx.lora_request = self._maybe_get_adapters(ctx.request)
ctx.tokenizer = await self.engine_client.get_tokenizer(
ctx.lora_request)
ctx.tokenizer = await self.engine_client.get_tokenizer()
renderer = self._get_renderer(ctx.tokenizer)
ctx.engine_prompts = await renderer.render_prompt(
......
......@@ -127,8 +127,7 @@ class OpenAIServingCompletion(OpenAIServing):
if self.model_config.skip_tokenizer_init:
tokenizer = None
else:
tokenizer = await self.engine_client.get_tokenizer(lora_request
)
tokenizer = await self.engine_client.get_tokenizer()
renderer = self._get_renderer(tokenizer)
engine_prompts = await renderer.render_prompt_and_embeds(
......
......@@ -76,8 +76,7 @@ class EmbeddingMixin(OpenAIServing):
try:
ctx.lora_request = self._maybe_get_adapters(ctx.request)
tokenizer = await self.engine_client.get_tokenizer(ctx.lora_request
)
tokenizer = await self.engine_client.get_tokenizer()
renderer = self._get_renderer(tokenizer)
if isinstance(ctx.request, EmbeddingChatRequest):
......@@ -394,8 +393,8 @@ class EmbeddingMixin(OpenAIServing):
) -> Optional[ErrorResponse]:
"""Collect and aggregate batch results
with support for chunked processing.
For chunked requests, performs online aggregation to
For chunked requests, performs online aggregation to
minimize memory usage.
For regular requests, collects results normally.
"""
......
......@@ -103,8 +103,7 @@ class OpenAIServingPooling(OpenAIServing):
if self.model_config.skip_tokenizer_init:
tokenizer = None
else:
tokenizer = await self.engine_client.get_tokenizer(lora_request
)
tokenizer = await self.engine_client.get_tokenizer()
renderer = self._get_renderer(tokenizer)
if getattr(request, "dimensions", None) is not None:
......
......@@ -240,7 +240,7 @@ class OpenAIServingResponses(OpenAIServing):
try:
lora_request = self._maybe_get_adapters(request)
model_name = self.models.model_name(lora_request)
tokenizer = await self.engine_client.get_tokenizer(lora_request)
tokenizer = await self.engine_client.get_tokenizer()
if self.use_harmony:
messages, request_prompts, engine_prompts = (
......
......@@ -269,7 +269,7 @@ class ServingScores(OpenAIServing):
) -> Union[list[PoolingRequestOutput], ErrorResponse]:
lora_request = self._maybe_get_adapters(request)
tokenizer = await self.engine_client.get_tokenizer(lora_request)
tokenizer = await self.engine_client.get_tokenizer()
truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens",
None)
......
......@@ -65,7 +65,7 @@ class OpenAIServingTokenization(OpenAIServing):
try:
lora_request = self._maybe_get_adapters(request)
tokenizer = await self.engine_client.get_tokenizer(lora_request)
tokenizer = await self.engine_client.get_tokenizer()
renderer = self._get_renderer(tokenizer)
if isinstance(request, TokenizeChatRequest):
......@@ -130,7 +130,7 @@ class OpenAIServingTokenization(OpenAIServing):
lora_request = self._maybe_get_adapters(request)
tokenizer = await self.engine_client.get_tokenizer(lora_request)
tokenizer = await self.engine_client.get_tokenizer()
self._log_inputs(request_id,
request.tokens,
......
......@@ -9,13 +9,11 @@ from typing_extensions import assert_never
from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.cache import BaseMultiModalProcessorCache
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
MultiModalInputs, MultiModalUUIDDict)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from .data import (DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt,
EncoderDecoderInputs, ProcessorInputs, PromptType,
......@@ -31,7 +29,7 @@ class InputPreprocessor:
def __init__(
self,
model_config: ModelConfig,
tokenizer: Optional[TokenizerGroup],
tokenizer: Optional[AnyTokenizer],
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
mm_processor_cache: Optional[BaseMultiModalProcessorCache] = None,
) -> None:
......@@ -42,32 +40,28 @@ class InputPreprocessor:
self.mm_registry = mm_registry
self.mm_processor_cache = mm_processor_cache
def get_tokenizer_group(self) -> TokenizerGroup:
def get_tokenizer(self) -> AnyTokenizer:
if self.tokenizer is None:
raise ValueError("You cannot pass text prompts when "
"`skip_tokenizer_init` is True")
return self.tokenizer
def get_bos_token_id(self,
lora_request: Optional[LoRARequest] = None
) -> Optional[int]:
def get_bos_token_id(self) -> Optional[int]:
if self.tokenizer is None:
logger.warning("Using None for BOS token id because tokenizer "
"is not initialized")
return None
return self.tokenizer.get_lora_tokenizer(lora_request).bos_token_id
return self.tokenizer.bos_token_id
def get_eos_token_id(self,
lora_request: Optional[LoRARequest] = None
) -> Optional[int]:
def get_eos_token_id(self) -> Optional[int]:
if self.tokenizer is None:
logger.warning("Using None for EOS token id because tokenizer "
"is not initialized")
return None
return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id
return self.tokenizer.eos_token_id
def get_decoder_start_token_id(self) -> Optional[int]:
"""
......@@ -190,14 +184,13 @@ class InputPreprocessor:
def _tokenize_prompt(
self,
prompt: str,
lora_request: Optional[LoRARequest],
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[int]:
"""
Apply the model's tokenizer to a text prompt, returning the
corresponding token IDs.
"""
tokenizer = self.get_tokenizer_group()
tokenizer = self.get_tokenizer()
tokenization_kwargs = self._get_tokenization_kw(tokenization_kwargs)
encoder_config = self.model_config.encoder_config
......@@ -205,50 +198,39 @@ class InputPreprocessor:
if encoder_config and encoder_config.get("do_lower_case", False):
prompt = prompt.lower()
return tokenizer.encode(prompt=prompt,
lora_request=lora_request,
**tokenization_kwargs)
return tokenizer.encode(prompt, **tokenization_kwargs)
async def _tokenize_prompt_async(
self,
prompt: str,
lora_request: Optional[LoRARequest],
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[int]:
"""
Async version of
[`_tokenize_prompt`][vllm.inputs.preprocess.InputPreprocessor._tokenize_prompt].
"""
tokenizer = self.get_tokenizer_group()
tokenizer = self.get_tokenizer()
tokenization_kwargs = self._get_tokenization_kw(tokenization_kwargs)
return await tokenizer.encode_async(prompt=prompt,
lora_request=lora_request,
**tokenization_kwargs)
return tokenizer.encode(prompt, **tokenization_kwargs)
def _get_mm_tokenizer(
self,
lora_request: Optional[LoRARequest],
) -> AnyTokenizer:
def _get_mm_tokenizer(self) -> AnyTokenizer:
# PrithviGeoSpatialMAE needs to be initialized without a tokenizer
# while using also multi-modal input
if not self.tokenizer:
return cast(AnyTokenizer, object()) # Dummy
tokenizer_group = self.get_tokenizer_group()
return tokenizer_group.get_lora_tokenizer(lora_request)
tokenizer = self.get_tokenizer()
return tokenizer
async def _get_mm_tokenizer_async(
self,
lora_request: Optional[LoRARequest],
) -> AnyTokenizer:
async def _get_mm_tokenizer_async(self) -> AnyTokenizer:
# PrithviGeoSpatialMAE needs to be initialized without a tokenizer
# while using also multi-modal input
if not self.tokenizer:
return cast(AnyTokenizer, object()) # Dummy
tokenizer_group = self.get_tokenizer_group()
return await tokenizer_group.get_lora_tokenizer_async(lora_request)
tokenizer = self.get_tokenizer()
return tokenizer
def _process_multimodal(
self,
......@@ -256,7 +238,6 @@ class InputPreprocessor:
mm_data: MultiModalDataDict,
mm_processor_kwargs: Optional[Mapping[str, object]],
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> MultiModalInputs:
......@@ -264,7 +245,7 @@ class InputPreprocessor:
Apply the model's multi-modal processor to a multi-modal prompt,
returning the corresponding token IDs and metadata.
"""
tokenizer = self._get_mm_tokenizer(lora_request)
tokenizer = self._get_mm_tokenizer()
mm_processor = self.mm_registry.create_processor(
self.model_config,
......@@ -299,7 +280,6 @@ class InputPreprocessor:
mm_data: MultiModalDataDict,
mm_processor_kwargs: Optional[Mapping[str, object]],
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> MultiModalInputs:
......@@ -307,7 +287,7 @@ class InputPreprocessor:
Async version of
[`_process_multimodal`][vllm.inputs.preprocess.InputPreprocessor._process_multimodal].
"""
tokenizer = await self._get_mm_tokenizer_async(lora_request)
tokenizer = await self._get_mm_tokenizer_async()
mm_processor = self.mm_registry.create_processor(
self.model_config,
......@@ -386,7 +366,6 @@ class InputPreprocessor:
self,
parsed_content: TokensPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> Union[TokenInputs, MultiModalInputs]:
......@@ -400,7 +379,6 @@ class InputPreprocessor:
multi_modal_data,
parsed_content.get("mm_processor_kwargs"),
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids,
)
else:
......@@ -415,7 +393,6 @@ class InputPreprocessor:
self,
parsed_content: TokensPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> Union[TokenInputs, MultiModalInputs]:
......@@ -429,7 +406,6 @@ class InputPreprocessor:
multi_modal_data,
parsed_content.get("mm_processor_kwargs"),
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids,
)
else:
......@@ -444,7 +420,6 @@ class InputPreprocessor:
self,
parsed_content: TextPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> Union[TokenInputs, MultiModalInputs]:
......@@ -457,13 +432,11 @@ class InputPreprocessor:
multi_modal_data,
parsed_content.get("mm_processor_kwargs"),
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids,
)
else:
prompt_token_ids = self._tokenize_prompt(
prompt_text,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
)
inputs = token_inputs(
......@@ -480,7 +453,6 @@ class InputPreprocessor:
self,
parsed_content: TextPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> Union[TokenInputs, MultiModalInputs]:
......@@ -493,13 +465,11 @@ class InputPreprocessor:
multi_modal_data,
parsed_content.get("mm_processor_kwargs"),
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids,
)
else:
prompt_token_ids = await self._tokenize_prompt_async(
prompt_text,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
)
inputs = token_inputs(
......@@ -516,7 +486,6 @@ class InputPreprocessor:
self,
prompt: SingletonPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> SingletonInputs:
......@@ -526,7 +495,6 @@ class InputPreprocessor:
Arguments:
* prompt: single encoder or decoder input prompt
* lora_request: this is only valid for decoder prompts
Returns:
......@@ -539,21 +507,18 @@ class InputPreprocessor:
if parsed["type"] == "tokens":
return self._process_tokens(
parsed["content"],
lora_request=lora_request,
mm_uuids=mm_uuids,
)
if parsed["type"] == "text":
return self._process_text(
parsed["content"],
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids,
)
if parsed["type"] == "str":
return self._process_text(
TextPrompt(prompt=parsed["content"]),
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids,
)
......@@ -563,7 +528,6 @@ class InputPreprocessor:
self,
prompt: SingletonPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> SingletonInputs:
......@@ -578,21 +542,18 @@ class InputPreprocessor:
if parsed["type"] == "tokens":
return await self._process_tokens_async(
parsed["content"],
lora_request=lora_request,
mm_uuids=mm_uuids,
)
if parsed["type"] == "text":
return await self._process_text_async(
parsed["content"],
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids,
)
if parsed["type"] == "str":
return await self._process_text_async(
TextPrompt(prompt=parsed["content"]),
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids,
)
......@@ -844,7 +805,6 @@ class InputPreprocessor:
self,
prompt: SingletonPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> DecoderOnlyInputs:
......@@ -856,7 +816,6 @@ class InputPreprocessor:
Arguments:
* prompt: input prompt
* lora_request
Returns:
......@@ -866,7 +825,6 @@ class InputPreprocessor:
prompt_comps = self._prompt_to_llm_inputs(
prompt,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids,
)
......@@ -876,7 +834,6 @@ class InputPreprocessor:
self,
prompt: SingletonPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> DecoderOnlyInputs:
......@@ -887,7 +844,6 @@ class InputPreprocessor:
prompt_comps = await self._prompt_to_llm_inputs_async(
prompt,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids,
)
......@@ -897,7 +853,6 @@ class InputPreprocessor:
self,
prompt: PromptType,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> ProcessorInputs:
......@@ -919,7 +874,6 @@ class InputPreprocessor:
return self._process_decoder_only_prompt(
prompt,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids,
)
......@@ -927,7 +881,6 @@ class InputPreprocessor:
self,
prompt: PromptType,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> ProcessorInputs:
......@@ -952,7 +905,6 @@ class InputPreprocessor:
return await self._process_decoder_only_prompt_async(
prompt,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids,
)
......
......@@ -10,18 +10,13 @@ from vllm.sequence import (VLLM_INVALID_TOKEN_ID, SamplingParams, Sequence,
from .detokenizer_utils import (convert_prompt_ids_to_tokens,
detokenize_incrementally)
from .tokenizer import AnyTokenizer
from .tokenizer_group import TokenizerGroup
class Detokenizer:
"""Provides methods to decode the output of a model into text."""
def __init__(self, tokenizer_group: TokenizerGroup):
self.tokenizer_group = tokenizer_group
def get_tokenizer_for_seq(self, sequence: Sequence) -> AnyTokenizer:
"""Returns the HF tokenizer to use for a given sequence."""
return self.tokenizer_group.get_lora_tokenizer(sequence.lora_request)
def __init__(self, tokenizer: AnyTokenizer):
self.tokenizer = tokenizer
def decode_prompt_logprobs_inplace(self, seq_group: SequenceGroup,
prompt_logprobs: list[Optional[dict[
......@@ -32,9 +27,9 @@ class Detokenizer:
Args:
seq_group: The sequence group to decode.
prompt_logprobs: The logprobs to decode.
position_offset: Offset of the first index of the logprobs
position_offset: Offset of the first index of the logprobs
relative to the start of the sequence (for chunked prefill).
Returns:
The prompt logprobs with the decoded tokens.
"""
......@@ -46,7 +41,6 @@ class Detokenizer:
# Only prompt, without the generated token.
all_token_ids = seq.get_token_ids()
prompt_token_ids = all_token_ids[:-1]
tokenizer = self.get_tokenizer_for_seq(seq)
prefix_offset = 0
read_offset = 0
next_iter_prefix_offset = 0
......@@ -70,7 +64,7 @@ class Detokenizer:
prompt_token_ids[:token_position] + [token_id])
(new_tokens, new_text, new_prefix_offset,
new_read_offset) = detokenize_incrementally(
tokenizer=tokenizer,
tokenizer=self.tokenizer,
all_input_ids=prompt_token_ids_with_token,
prev_tokens=prev_tokens,
prefix_offset=prefix_offset,
......@@ -111,7 +105,6 @@ class Detokenizer:
"""
all_input_ids = seq.get_token_ids()
token_id_generated_this_iteration = all_input_ids[-1]
tokenizer = self.get_tokenizer_for_seq(seq)
# Convert prompt token IDs to tokens if necessary.
# Do it here so that we don't have to repeat this
......@@ -119,14 +112,14 @@ class Detokenizer:
if seq.tokens is None:
(seq.tokens, seq.prefix_offset,
seq.read_offset) = convert_prompt_ids_to_tokens(
tokenizer=tokenizer,
tokenizer=self.tokenizer,
prompt_ids=all_input_ids[:-1],
skip_special_tokens=prms.skip_special_tokens,
)
(new_tokens, new_decoded_token_text, prefix_offset,
read_offset) = detokenize_incrementally(
tokenizer=tokenizer,
tokenizer=self.tokenizer,
all_input_ids=all_input_ids,
prev_tokens=seq.tokens,
prefix_offset=seq.prefix_offset,
......@@ -150,7 +143,7 @@ class Detokenizer:
and token_id != VLLM_INVALID_TOKEN_ID):
all_input_ids_with_logprob = previous_tokens + [token_id]
(_, new_text, _, _) = detokenize_incrementally(
tokenizer=tokenizer,
tokenizer=self.tokenizer,
all_input_ids=all_input_ids_with_logprob,
prev_tokens=seq.tokens,
prefix_offset=seq.prefix_offset,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment