Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6c47f6bf
Unverified
Commit
6c47f6bf
authored
Sep 17, 2025
by
Zhuohan Li
Committed by
GitHub
Sep 17, 2025
Browse files
[Core] Remove tokenizer group in vLLM (#24078)
Signed-off-by:
Zhuohan Li
<
zhuohan123@gmail.com
>
parent
c15309a7
Changes
49
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
155 additions
and
284 deletions
+155
-284
tests/v1/engine/test_output_processor.py
tests/v1/engine/test_output_processor.py
+5
-5
tests/v1/engine/utils.py
tests/v1/engine/utils.py
+2
-4
tests/v1/entrypoints/llm/test_struct_output_generate.py
tests/v1/entrypoints/llm/test_struct_output_generate.py
+1
-1
vllm/benchmarks/datasets.py
vllm/benchmarks/datasets.py
+80
-93
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+4
-11
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+14
-43
vllm/engine/output_processor/interfaces.py
vllm/engine/output_processor/interfaces.py
+2
-4
vllm/engine/output_processor/stop_checker.py
vllm/engine/output_processor/stop_checker.py
+1
-4
vllm/engine/protocol.py
vllm/engine/protocol.py
+3
-7
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+6
-14
vllm/entrypoints/openai/serving_chat.py
vllm/entrypoints/openai/serving_chat.py
+1
-1
vllm/entrypoints/openai/serving_classification.py
vllm/entrypoints/openai/serving_classification.py
+1
-4
vllm/entrypoints/openai/serving_completion.py
vllm/entrypoints/openai/serving_completion.py
+1
-2
vllm/entrypoints/openai/serving_embedding.py
vllm/entrypoints/openai/serving_embedding.py
+3
-4
vllm/entrypoints/openai/serving_pooling.py
vllm/entrypoints/openai/serving_pooling.py
+1
-2
vllm/entrypoints/openai/serving_responses.py
vllm/entrypoints/openai/serving_responses.py
+1
-1
vllm/entrypoints/openai/serving_score.py
vllm/entrypoints/openai/serving_score.py
+1
-1
vllm/entrypoints/openai/serving_tokenization.py
vllm/entrypoints/openai/serving_tokenization.py
+2
-2
vllm/inputs/preprocess.py
vllm/inputs/preprocess.py
+18
-66
vllm/transformers_utils/detokenizer.py
vllm/transformers_utils/detokenizer.py
+8
-15
No files found.
tests/v1/engine/test_output_processor.py
View file @
6c47f6bf
...
@@ -43,7 +43,7 @@ def _ref_convert_id_to_token(
...
@@ -43,7 +43,7 @@ def _ref_convert_id_to_token(
[
RequestOutputKind
.
DELTA
,
RequestOutputKind
.
FINAL_ONLY
])
[
RequestOutputKind
.
DELTA
,
RequestOutputKind
.
FINAL_ONLY
])
def
test_incremental_detokenization
(
request_output_kind
:
RequestOutputKind
,
def
test_incremental_detokenization
(
request_output_kind
:
RequestOutputKind
,
dummy_test_vectors
):
dummy_test_vectors
):
output_processor
=
OutputProcessor
(
dummy_test_vectors
.
tokenizer
_group
,
output_processor
=
OutputProcessor
(
dummy_test_vectors
.
tokenizer
,
log_stats
=
False
)
log_stats
=
False
)
engine_core
=
MockEngineCore
(
engine_core
=
MockEngineCore
(
tokens_list
=
dummy_test_vectors
.
generation_tokens
)
tokens_list
=
dummy_test_vectors
.
generation_tokens
)
...
@@ -382,7 +382,7 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind,
...
@@ -382,7 +382,7 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind,
num_sample_logprobs
:
Optional
[
int
],
num_sample_logprobs
:
Optional
[
int
],
num_prompt_logprobs
:
Optional
[
int
],
num_prompt_logprobs
:
Optional
[
int
],
dummy_test_vectors
):
dummy_test_vectors
):
output_processor
=
OutputProcessor
(
dummy_test_vectors
.
tokenizer
_group
,
output_processor
=
OutputProcessor
(
dummy_test_vectors
.
tokenizer
,
log_stats
=
False
)
log_stats
=
False
)
engine_core
=
MockEngineCore
(
engine_core
=
MockEngineCore
(
tokens_list
=
dummy_test_vectors
.
generation_tokens
,
tokens_list
=
dummy_test_vectors
.
generation_tokens
,
...
@@ -535,7 +535,7 @@ def test_stop_token(include_stop_str_in_output: bool,
...
@@ -535,7 +535,7 @@ def test_stop_token(include_stop_str_in_output: bool,
)
# '<|end_of_text|>'
)
# '<|end_of_text|>'
stop_token_ids
=
[
128009
]
if
not
is_eos_test
else
None
# '<|eot_id|>'
stop_token_ids
=
[
128009
]
if
not
is_eos_test
else
None
# '<|eot_id|>'
output_processor
=
OutputProcessor
(
dummy_test_vectors
.
tokenizer
_group
,
output_processor
=
OutputProcessor
(
dummy_test_vectors
.
tokenizer
,
log_stats
=
False
)
log_stats
=
False
)
# Dummy engine core outputs, with control tokens suffixed to test stops
# Dummy engine core outputs, with control tokens suffixed to test stops
suffix_token
=
([
eos_token_id
]
if
is_eos_test
else
stop_token_ids
)
suffix_token
=
([
eos_token_id
]
if
is_eos_test
else
stop_token_ids
)
...
@@ -642,7 +642,7 @@ def test_stop_token(include_stop_str_in_output: bool,
...
@@ -642,7 +642,7 @@ def test_stop_token(include_stop_str_in_output: bool,
[
None
,
NUM_SAMPLE_LOGPROBS_UNDER_TEST
])
[
None
,
NUM_SAMPLE_LOGPROBS_UNDER_TEST
])
def
test_stop_string
(
include_stop_str_in_output
:
bool
,
def
test_stop_string
(
include_stop_str_in_output
:
bool
,
num_sample_logprobs
:
Optional
[
int
],
dummy_test_vectors
):
num_sample_logprobs
:
Optional
[
int
],
dummy_test_vectors
):
output_processor
=
OutputProcessor
(
dummy_test_vectors
.
tokenizer
_group
,
output_processor
=
OutputProcessor
(
dummy_test_vectors
.
tokenizer
,
log_stats
=
False
)
log_stats
=
False
)
engine_core
=
MockEngineCore
(
engine_core
=
MockEngineCore
(
tokens_list
=
dummy_test_vectors
.
generation_tokens
,
tokens_list
=
dummy_test_vectors
.
generation_tokens
,
...
@@ -763,7 +763,7 @@ def test_stop_string(include_stop_str_in_output: bool,
...
@@ -763,7 +763,7 @@ def test_stop_string(include_stop_str_in_output: bool,
def
test_iteration_stats
(
dummy_test_vectors
):
def
test_iteration_stats
(
dummy_test_vectors
):
output_processor
=
OutputProcessor
(
dummy_test_vectors
.
tokenizer
_group
,
output_processor
=
OutputProcessor
(
dummy_test_vectors
.
tokenizer
,
log_stats
=
True
)
log_stats
=
True
)
engine_core
=
MockEngineCore
(
dummy_test_vectors
.
generation_tokens
)
engine_core
=
MockEngineCore
(
dummy_test_vectors
.
generation_tokens
)
engine_core_timestamp
=
time
.
monotonic
()
engine_core_timestamp
=
time
.
monotonic
()
...
...
tests/v1/engine/utils.py
View file @
6c47f6bf
...
@@ -9,7 +9,6 @@ import torch
...
@@ -9,7 +9,6 @@ import torch
from
transformers
import
PreTrainedTokenizer
,
PreTrainedTokenizerFast
from
transformers
import
PreTrainedTokenizer
,
PreTrainedTokenizerFast
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.transformers_utils.tokenizer_group
import
TokenizerGroup
from
vllm.v1.engine
import
EngineCoreOutput
,
FinishReason
from
vllm.v1.engine
import
EngineCoreOutput
,
FinishReason
from
vllm.v1.outputs
import
LogprobsLists
,
LogprobsTensors
from
vllm.v1.outputs
import
LogprobsLists
,
LogprobsTensors
...
@@ -39,7 +38,7 @@ def _create_random_top_logprob_test_vector(
...
@@ -39,7 +38,7 @@ def _create_random_top_logprob_test_vector(
upper
:
float
,
upper
:
float
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Create a random vector of top logprob float values.
"""Create a random vector of top logprob float values.
Use to create fake sample logprobs for testing.
Use to create fake sample logprobs for testing.
Note that a real production scenario would require
Note that a real production scenario would require
...
@@ -63,7 +62,7 @@ def _create_random_top_logprob_test_matrix(
...
@@ -63,7 +62,7 @@ def _create_random_top_logprob_test_matrix(
upper
:
float
,
upper
:
float
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Create a random matrix of top logprob float values.
"""Create a random matrix of top logprob float values.
Use to create fake prompt logprobs for testing.
Use to create fake prompt logprobs for testing.
Note that a real production scenario would require
Note that a real production scenario would require
...
@@ -296,7 +295,6 @@ def generate_dummy_prompt_logprobs_tensors(
...
@@ -296,7 +295,6 @@ def generate_dummy_prompt_logprobs_tensors(
class
DummyOutputProcessorTestVectors
:
class
DummyOutputProcessorTestVectors
:
"""Dummy test vectors for output processor tests"""
"""Dummy test vectors for output processor tests"""
tokenizer
:
GeneralTokenizerType
tokenizer
:
GeneralTokenizerType
tokenizer_group
:
TokenizerGroup
vllm_config
:
EngineArgs
vllm_config
:
EngineArgs
full_tokens
:
list
[
list
[
int
]]
# Prompt + generated tokens
full_tokens
:
list
[
list
[
int
]]
# Prompt + generated tokens
prompt_tokens
:
list
[
list
[
int
]]
prompt_tokens
:
list
[
list
[
int
]]
...
...
tests/v1/entrypoints/llm/test_struct_output_generate.py
View file @
6c47f6bf
...
@@ -582,7 +582,7 @@ def test_structured_output_with_reasoning_matrices(
...
@@ -582,7 +582,7 @@ def test_structured_output_with_reasoning_matrices(
reasoning_parser
=
reasoning_parser
,
reasoning_parser
=
reasoning_parser
,
speculative_config
=
speculative_config
,
speculative_config
=
speculative_config
,
)
)
tokenizer
=
llm
.
get_tokenizer
(
None
)
tokenizer
=
llm
.
get_tokenizer
()
reasoner
=
ReasoningParserManager
.
get_reasoning_parser
(
reasoning_parser
)(
reasoner
=
ReasoningParserManager
.
get_reasoning_parser
(
reasoning_parser
)(
tokenizer
=
tokenizer
)
tokenizer
=
tokenizer
)
...
...
vllm/benchmarks/datasets.py
View file @
6c47f6bf
...
@@ -37,7 +37,7 @@ from vllm.lora.request import LoRARequest
...
@@ -37,7 +37,7 @@ from vllm.lora.request import LoRARequest
from
vllm.lora.utils
import
get_adapter_absolute_path
from
vllm.lora.utils
import
get_adapter_absolute_path
from
vllm.multimodal
import
MultiModalDataDict
from
vllm.multimodal
import
MultiModalDataDict
from
vllm.multimodal.image
import
convert_image_mode
from
vllm.multimodal.image
import
convert_image_mode
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
,
get_lora_tokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.utils
import
PlaceholderModule
from
vllm.utils
import
PlaceholderModule
try
:
try
:
...
@@ -100,8 +100,8 @@ class BenchmarkDataset(ABC):
...
@@ -100,8 +100,8 @@ class BenchmarkDataset(ABC):
)
->
None
:
)
->
None
:
"""
"""
Initialize the BenchmarkDataset with an optional dataset path and random
Initialize the BenchmarkDataset with an optional dataset path and random
seed.
seed.
Args:
Args:
dataset_path (Optional[str]): Path to the dataset. If None, it
dataset_path (Optional[str]): Path to the dataset. If None, it
indicates that a default or random dataset might be used.
indicates that a default or random dataset might be used.
...
@@ -133,10 +133,10 @@ class BenchmarkDataset(ABC):
...
@@ -133,10 +133,10 @@ class BenchmarkDataset(ABC):
elif
isinstance
(
mm_content
,
dict
):
elif
isinstance
(
mm_content
,
dict
):
content
.
append
(
mm_content
)
content
.
append
(
mm_content
)
else
:
else
:
raise
TypeError
(
raise
TypeError
(
"Could not process multimodal content of type: "
+
"Could not process multimodal content of type: "
+
f
"
{
type
(
mm_content
)
}
"
f
"
{
type
(
mm_content
)
}
"
)
)
return
[{
"role"
:
"user"
,
"content"
:
content
}]
return
[{
"role"
:
"user"
,
"content"
:
content
}]
def
load_data
(
self
)
->
None
:
def
load_data
(
self
)
->
None
:
...
@@ -155,34 +155,26 @@ class BenchmarkDataset(ABC):
...
@@ -155,34 +155,26 @@ class BenchmarkDataset(ABC):
def
get_random_lora_request
(
def
get_random_lora_request
(
self
,
self
,
tokenizer
:
PreTrainedTokenizerBase
,
max_loras
:
Optional
[
int
]
=
None
,
max_loras
:
Optional
[
int
]
=
None
,
lora_path
:
Optional
[
str
]
=
None
,
lora_path
:
Optional
[
str
]
=
None
,
)
->
tuple
[
Optional
[
LoRARequest
]
,
AnyTokenizer
]
:
)
->
Optional
[
LoRARequest
]:
"""
"""
Optionally select a random LoRA request and return its associated
Optionally select a random LoRA request.
tokenizer.
This method is used when LoRA parameters are provided. It randomly
This method is used when LoRA parameters are provided. It randomly
selects a LoRA based on max_loras and retrieves a cached tokenizer for
selects a LoRA based on max_loras.
that LoRA if available. Otherwise, it returns the base tokenizer.
Args:
Args:
tokenizer (PreTrainedTokenizerBase): The base tokenizer to use if no
LoRA is selected.
max_loras (Optional[int]): The maximum number of LoRAs available.
max_loras (Optional[int]): The maximum number of LoRAs available.
If `None`, LoRA is not used.
If `None`, LoRA is not used.
lora_path (Optional[str]): Path to the LoRA parameters on disk.
lora_path (Optional[str]): Path to the LoRA parameters on disk.
If `None`, LoRA is not used.
If `None`, LoRA is not used.
Returns:
Returns:
A tuple with the following elements:
A new [LoRARequest][] (or `None` if not applicable).
- A new [LoRARequest][] (or `None` if not applicable).
- The tokenizer associated with the LoRA request
(or the base tokenizer).
"""
"""
if
max_loras
is
None
or
lora_path
is
None
:
if
max_loras
is
None
or
lora_path
is
None
:
return
None
,
tokenizer
return
None
# Generate a random LoRA ID in the range [1, max_loras].
# Generate a random LoRA ID in the range [1, max_loras].
lora_id
=
random
.
randint
(
1
,
max_loras
)
lora_id
=
random
.
randint
(
1
,
max_loras
)
...
@@ -191,11 +183,7 @@ class BenchmarkDataset(ABC):
...
@@ -191,11 +183,7 @@ class BenchmarkDataset(ABC):
lora_int_id
=
lora_id
,
lora_int_id
=
lora_id
,
lora_path
=
lora_path_on_disk
(
lora_path
),
lora_path
=
lora_path_on_disk
(
lora_path
),
)
)
if
lora_id
not
in
lora_tokenizer_cache
:
return
lora_request
lora_tokenizer_cache
[
lora_id
]
=
get_lora_tokenizer
(
lora_request
)
# Return lora_request and the cached tokenizer if available; otherwise,
# return the base tokenizer
return
lora_request
,
lora_tokenizer_cache
[
lora_id
]
or
tokenizer
@
abstractmethod
@
abstractmethod
def
sample
(
self
,
tokenizer
:
PreTrainedTokenizerBase
,
def
sample
(
self
,
tokenizer
:
PreTrainedTokenizerBase
,
...
@@ -213,7 +201,7 @@ class BenchmarkDataset(ABC):
...
@@ -213,7 +201,7 @@ class BenchmarkDataset(ABC):
for processing the dataset's text.
for processing the dataset's text.
num_requests (int): The number of sample requests to generate.
num_requests (int): The number of sample requests to generate.
request_id_prefix (str) The prefix of request_id.
request_id_prefix (str) The prefix of request_id.
Returns:
Returns:
list[SampleRequest]: A list of sample requests generated from the
list[SampleRequest]: A list of sample requests generated from the
...
@@ -527,7 +515,7 @@ class RandomDataset(BenchmarkDataset):
...
@@ -527,7 +515,7 @@ class RandomDataset(BenchmarkDataset):
size
=
num_requests
)
size
=
num_requests
)
output_lens
=
self
.
_rng
.
integers
(
output_low
,
output_high
+
1
,
output_lens
=
self
.
_rng
.
integers
(
output_low
,
output_high
+
1
,
size
=
num_requests
)
size
=
num_requests
)
offsets
=
self
.
_rng
.
integers
(
0
,
tokenizer
.
vocab_size
,
offsets
=
self
.
_rng
.
integers
(
0
,
tokenizer
.
vocab_size
,
size
=
num_requests
)
size
=
num_requests
)
return
input_lens
,
output_lens
,
offsets
return
input_lens
,
output_lens
,
offsets
...
@@ -555,7 +543,7 @@ class RandomDataset(BenchmarkDataset):
...
@@ -555,7 +543,7 @@ class RandomDataset(BenchmarkDataset):
the encoded sequence is truncated before being decoded again.
the encoded sequence is truncated before being decoded again.
"""
"""
# Build the inner sequence by sampling sequentially from the vocab
# Build the inner sequence by sampling sequentially from the vocab
inner_seq
=
((
offset
+
index
+
np
.
arange
(
input_len
))
inner_seq
=
((
offset
+
index
+
np
.
arange
(
input_len
))
%
vocab_size
).
tolist
()
%
vocab_size
).
tolist
()
token_sequence
=
prefix_token_ids
+
inner_seq
token_sequence
=
prefix_token_ids
+
inner_seq
...
@@ -590,9 +578,9 @@ class RandomMultiModalDataset(RandomDataset):
...
@@ -590,9 +578,9 @@ class RandomMultiModalDataset(RandomDataset):
`num_mm_items_range_ratio` in [0, 1]. r=0 keeps it fixed; r=1 allows 0.
`num_mm_items_range_ratio` in [0, 1]. r=0 keeps it fixed; r=1 allows 0.
The maximum is further clamped to the sum of per-modality limits.
The maximum is further clamped to the sum of per-modality limits.
2) Each item’s modality and shape is sampled from `bucket_config`, a dict
2) Each item’s modality and shape is sampled from `bucket_config`, a dict
mapping (height, width, num_frames) → probability. We treat
mapping (height, width, num_frames) → probability. We treat
`num_frames`=1 as image and and `num_frames` > 1 as video.
`num_frames`=1 as image and and `num_frames` > 1 as video.
Entries with zero probability are removed and the rest are renormalized
Entries with zero probability are removed and the rest are renormalized
to sum to 1.
to sum to 1.
3) Per-modality hard caps are enforced via `limit_mm_per_prompt`.
3) Per-modality hard caps are enforced via `limit_mm_per_prompt`.
When a modality reaches its cap, all of its buckets are excluded and the
When a modality reaches its cap, all of its buckets are excluded and the
...
@@ -600,8 +588,8 @@ class RandomMultiModalDataset(RandomDataset):
...
@@ -600,8 +588,8 @@ class RandomMultiModalDataset(RandomDataset):
Example bucket configuration:
Example bucket configuration:
{(256, 256, 1): 0.5, (720, 1280, 1): 0.4, (720, 1280, 16): 0.1}
{(256, 256, 1): 0.5, (720, 1280, 1): 0.4, (720, 1280, 16): 0.1}
- Two image buckets (`num_frames`=1) and one video bucket
- Two image buckets (`num_frames`=1) and one video bucket
(`num_frames`=16).
(`num_frames`=16).
OBS.: Only image sampling is supported for now.
OBS.: Only image sampling is supported for now.
"""
"""
...
@@ -624,9 +612,9 @@ class RandomMultiModalDataset(RandomDataset):
...
@@ -624,9 +612,9 @@ class RandomMultiModalDataset(RandomDataset):
def
generate_synthetic_image
(
self
,
width
:
int
,
height
:
int
)
->
Image
.
Image
:
def
generate_synthetic_image
(
self
,
width
:
int
,
height
:
int
)
->
Image
.
Image
:
"""Generate synthetic PIL image with random RGB values.
"""Generate synthetic PIL image with random RGB values.
NOTE: iid pixel sampling results in worst-case compression
NOTE: iid pixel sampling results in worst-case compression
(good for stressing I/O), but very unlike real photos.
(good for stressing I/O), but very unlike real photos.
We could consider a “low-freq” mode (e.g., noise blur)
We could consider a “low-freq” mode (e.g., noise blur)
to emulate network realism instead of max stress.
to emulate network realism instead of max stress.
"""
"""
...
@@ -638,11 +626,11 @@ class RandomMultiModalDataset(RandomDataset):
...
@@ -638,11 +626,11 @@ class RandomMultiModalDataset(RandomDataset):
)
)
return
Image
.
fromarray
(
random_pixels
)
return
Image
.
fromarray
(
random_pixels
)
def
generate_synthetic_video
(
self
,
width
:
int
,
def
generate_synthetic_video
(
self
,
width
:
int
,
height
:
int
,
height
:
int
,
num_frames
:
int
)
->
Any
:
num_frames
:
int
)
->
Any
:
"""Generate synthetic video with random values.
"""Generate synthetic video with random values.
TODO: Finish this method.
TODO: Finish this method.
"""
"""
raise
NotImplementedError
(
"Video sampling is WIP."
)
raise
NotImplementedError
(
"Video sampling is WIP."
)
...
@@ -656,7 +644,7 @@ class RandomMultiModalDataset(RandomDataset):
...
@@ -656,7 +644,7 @@ class RandomMultiModalDataset(RandomDataset):
else
:
else
:
raise
ValueError
(
f
"Invalid multimodal item configuration:
{
config
}
"
)
raise
ValueError
(
f
"Invalid multimodal item configuration:
{
config
}
"
)
def
normalize_bucket_config
(
self
,
bucket_config
:
dict
[
tuple
[
int
,
int
,
int
],
def
normalize_bucket_config
(
self
,
bucket_config
:
dict
[
tuple
[
int
,
int
,
int
],
float
])
->
dict
[
tuple
[
int
,
int
,
int
],
float
]:
float
])
->
dict
[
tuple
[
int
,
int
,
int
],
float
]:
"""
"""
Remove zero probability entries
Remove zero probability entries
...
@@ -676,24 +664,24 @@ class RandomMultiModalDataset(RandomDataset):
...
@@ -676,24 +664,24 @@ class RandomMultiModalDataset(RandomDataset):
return
{
k
:
v
/
total
for
k
,
v
in
bucket_config
.
items
()}
return
{
k
:
v
/
total
for
k
,
v
in
bucket_config
.
items
()}
def
generate_mm_item
(
self
,
def
generate_mm_item
(
self
,
mm_item_config
:
tuple
[
int
,
int
,
int
],
mm_item_config
:
tuple
[
int
,
int
,
int
],
)
->
Mapping
[
str
,
Any
]:
)
->
Mapping
[
str
,
Any
]:
"""
"""
Create synthetic images and videos and
Create synthetic images and videos and
apply process_image/process_video respectively.
apply process_image/process_video respectively.
This follows the OpenAI API chat completions
This follows the OpenAI API chat completions
https://github.com/openai/openai-python
https://github.com/openai/openai-python
"""
"""
if
self
.
map_config_to_modality
(
mm_item_config
)
==
"image"
:
if
self
.
map_config_to_modality
(
mm_item_config
)
==
"image"
:
return
process_image
(
self
.
generate_synthetic_image
(
return
process_image
(
self
.
generate_synthetic_image
(
mm_item_config
[
1
],
mm_item_config
[
1
],
mm_item_config
[
0
]))
mm_item_config
[
0
]))
elif
self
.
map_config_to_modality
(
mm_item_config
)
==
"video"
:
elif
self
.
map_config_to_modality
(
mm_item_config
)
==
"video"
:
return
process_video
(
self
.
generate_synthetic_video
(
return
process_video
(
self
.
generate_synthetic_video
(
mm_item_config
[
1
],
mm_item_config
[
1
],
mm_item_config
[
0
],
mm_item_config
[
0
],
mm_item_config
[
2
]))
mm_item_config
[
2
]))
else
:
else
:
raise
ValueError
(
f
"Invalid multimodal item configuration: "
raise
ValueError
(
f
"Invalid multimodal item configuration: "
...
@@ -723,17 +711,17 @@ class RandomMultiModalDataset(RandomDataset):
...
@@ -723,17 +711,17 @@ class RandomMultiModalDataset(RandomDataset):
f
"limit_mm_per_prompt: "
f
"limit_mm_per_prompt: "
f
"
{
limit_mm_per_prompt
.
keys
()
}
"
)
f
"
{
limit_mm_per_prompt
.
keys
()
}
"
)
# Remove zero probability entries
# Remove zero probability entries
# and normalize bucket config to sum to 1
# and normalize bucket config to sum to 1
bucket_config
=
self
.
normalize_bucket_config
(
bucket_config
)
bucket_config
=
self
.
normalize_bucket_config
(
bucket_config
)
logger
.
info
(
logger
.
info
(
"Normalized bucket config: %s"
,
bucket_config
,
"Normalized bucket config: %s"
,
bucket_config
,
)
)
# Only consider limit per prompt for modalities in bucket config
# Only consider limit per prompt for modalities in bucket config
allowed_modalities
=
{
self
.
map_config_to_modality
(
cfg
)
allowed_modalities
=
{
self
.
map_config_to_modality
(
cfg
)
for
cfg
in
bucket_config
}
for
cfg
in
bucket_config
}
limit_mm_per_prompt
=
{
limit_mm_per_prompt
=
{
k
:
v
for
k
,
v
in
limit_mm_per_prompt
.
items
()
k
:
v
for
k
,
v
in
limit_mm_per_prompt
.
items
()
if
k
in
allowed_modalities
}
if
k
in
allowed_modalities
}
if
not
limit_mm_per_prompt
:
if
not
limit_mm_per_prompt
:
raise
ValueError
(
"No valid limits for modalities present in "
raise
ValueError
(
"No valid limits for modalities present in "
...
@@ -746,19 +734,19 @@ class RandomMultiModalDataset(RandomDataset):
...
@@ -746,19 +734,19 @@ class RandomMultiModalDataset(RandomDataset):
# Get max and min num mm items and ensure
# Get max and min num mm items and ensure
# it is at most the sum of limit_mm_per_prompt for all modalities
# it is at most the sum of limit_mm_per_prompt for all modalities
max_num_mm_items
=
min
(
max_num_mm_items
=
min
(
sum
(
limit_mm_per_prompt
.
values
()),
sum
(
limit_mm_per_prompt
.
values
()),
math
.
ceil
(
base_items_per_request
*
(
1
+
num_mm_items_range_ratio
))
math
.
ceil
(
base_items_per_request
*
(
1
+
num_mm_items_range_ratio
))
)
)
# Ensure min num mm items is at least 0
# Ensure min num mm items is at least 0
min_num_mm_items
=
max
(
min_num_mm_items
=
max
(
0
,
0
,
math
.
floor
(
base_items_per_request
*
(
1
-
num_mm_items_range_ratio
))
math
.
floor
(
base_items_per_request
*
(
1
-
num_mm_items_range_ratio
))
)
)
# Raise error if min num mm items is greater than max num mm items
# Raise error if min num mm items is greater than max num mm items
if
min_num_mm_items
>
max_num_mm_items
:
if
min_num_mm_items
>
max_num_mm_items
:
raise
ValueError
(
f
"Min num mm items is greater than max mm items: "
raise
ValueError
(
f
"Min num mm items is greater than max mm items: "
f
"
{
min_num_mm_items
}
>
{
max_num_mm_items
}
"
)
f
"
{
min_num_mm_items
}
>
{
max_num_mm_items
}
"
)
logger
.
info
(
logger
.
info
(
"Sampling number of multimodal items from [%s, %s]"
,
"Sampling number of multimodal items from [%s, %s]"
,
min_num_mm_items
,
max_num_mm_items
,
min_num_mm_items
,
max_num_mm_items
,
...
@@ -783,8 +771,8 @@ class RandomMultiModalDataset(RandomDataset):
...
@@ -783,8 +771,8 @@ class RandomMultiModalDataset(RandomDataset):
whose size is between min_num_mm_items and max_num_mm_items.
whose size is between min_num_mm_items and max_num_mm_items.
Loop over the bucket config and sample a multimodal item.
Loop over the bucket config and sample a multimodal item.
Loop until the number of multimodal items sampled is equal to
Loop until the number of multimodal items sampled is equal to
request_num_mm_items or limit of multimodal items per prompt
request_num_mm_items or limit of multimodal items per prompt
for all modalities is reached.
for all modalities is reached.
Note:
Note:
...
@@ -796,19 +784,19 @@ class RandomMultiModalDataset(RandomDataset):
...
@@ -796,19 +784,19 @@ class RandomMultiModalDataset(RandomDataset):
# Get the number of multimodal items to sample
# Get the number of multimodal items to sample
request_num_mm_items
=
int
(
request_num_mm_items
=
int
(
self
.
_rng
.
integers
(
min_num_mm_items
,
max_num_mm_items
+
1
)
self
.
_rng
.
integers
(
min_num_mm_items
,
max_num_mm_items
+
1
)
)
)
# If request_num_mm_items is 0, yield an empty iterator
# If request_num_mm_items is 0, yield an empty iterator
if
request_num_mm_items
==
0
:
if
request_num_mm_items
==
0
:
return
return
# Initialize modality counters
# Initialize modality counters
modality_counter
=
{
self
.
map_config_to_modality
(
k
):
0
modality_counter
=
{
self
.
map_config_to_modality
(
k
):
0
for
k
in
bucket_config
}
for
k
in
bucket_config
}
# Copy the bucket config to avoid modifying the original
# Copy the bucket config to avoid modifying the original
bucket_config_copy
=
bucket_config
.
copy
()
bucket_config_copy
=
bucket_config
.
copy
()
# Loop over the number of multimodal items to sample
# Loop over the number of multimodal items to sample
while
sum
(
modality_counter
.
values
())
<
request_num_mm_items
:
while
sum
(
modality_counter
.
values
())
<
request_num_mm_items
:
# Sample a multimodal item config
# Sample a multimodal item config
mm_item_config
=
self
.
_rng
.
choice
(
list
(
bucket_config_copy
.
keys
()),
mm_item_config
=
self
.
_rng
.
choice
(
list
(
bucket_config_copy
.
keys
()),
p
=
list
(
bucket_config_copy
.
values
()))
p
=
list
(
bucket_config_copy
.
values
()))
modality
=
self
.
map_config_to_modality
(
mm_item_config
)
modality
=
self
.
map_config_to_modality
(
mm_item_config
)
# Check that modality count is less than limit per prompt
# Check that modality count is less than limit per prompt
...
@@ -849,7 +837,7 @@ class RandomMultiModalDataset(RandomDataset):
...
@@ -849,7 +837,7 @@ class RandomMultiModalDataset(RandomDataset):
limit_mm_per_prompt
:
dict
[
str
,
int
]
=
DEFAULT_LIMIT_MM_PER_PROMPT
,
limit_mm_per_prompt
:
dict
[
str
,
int
]
=
DEFAULT_LIMIT_MM_PER_PROMPT
,
base_items_per_request
:
int
=
DEFAULT_BASE_ITEMS_PER_REQUEST
,
base_items_per_request
:
int
=
DEFAULT_BASE_ITEMS_PER_REQUEST
,
num_mm_items_range_ratio
:
float
=
DEFAULT_NUM_MM_ITEMS_RANGE_RATIO
,
num_mm_items_range_ratio
:
float
=
DEFAULT_NUM_MM_ITEMS_RANGE_RATIO
,
bucket_config
:
dict
[
tuple
[
int
,
int
,
int
],
float
]
=
bucket_config
:
dict
[
tuple
[
int
,
int
,
int
],
float
]
=
DEFAULT_MM_ITEM_BUCKET_CONFIG
,
DEFAULT_MM_ITEM_BUCKET_CONFIG
,
enable_multimodal_chat
:
bool
=
DEFAULT_ENABLE_MULTIMODAL_CHAT
,
enable_multimodal_chat
:
bool
=
DEFAULT_ENABLE_MULTIMODAL_CHAT
,
**
kwargs
,
**
kwargs
,
...
@@ -857,7 +845,7 @@ class RandomMultiModalDataset(RandomDataset):
...
@@ -857,7 +845,7 @@ class RandomMultiModalDataset(RandomDataset):
# NOTE: Video sampling is WIP. Raise error if video is in bucket config
# NOTE: Video sampling is WIP. Raise error if video is in bucket config
# and probability is non-zero.
# and probability is non-zero.
if
any
(
self
.
map_config_to_modality
(
cfg
)
==
"video"
and
p
>
0
if
any
(
self
.
map_config_to_modality
(
cfg
)
==
"video"
and
p
>
0
for
cfg
,
p
in
bucket_config
.
items
()):
for
cfg
,
p
in
bucket_config
.
items
()):
raise
NotImplementedError
(
"Video sampling not implemented; "
raise
NotImplementedError
(
"Video sampling not implemented; "
"set its probability to 0."
)
"set its probability to 0."
)
...
@@ -908,7 +896,7 @@ class RandomMultiModalDataset(RandomDataset):
...
@@ -908,7 +896,7 @@ class RandomMultiModalDataset(RandomDataset):
])
])
if
enable_multimodal_chat
:
if
enable_multimodal_chat
:
# NOTE: For now this option is only provided for completeness
# NOTE: For now this option is only provided for completeness
# given that the serve.py benchmark currently does not use it.
# given that the serve.py benchmark currently does not use it.
mm_chat_prompt
:
Any
=
prompt
mm_chat_prompt
:
Any
=
prompt
mm_chat_prompt
=
self
.
apply_multimodal_chat_transformation
(
mm_chat_prompt
=
self
.
apply_multimodal_chat_transformation
(
...
@@ -982,8 +970,8 @@ class ShareGPTDataset(BenchmarkDataset):
...
@@ -982,8 +970,8 @@ class ShareGPTDataset(BenchmarkDataset):
entry
[
"conversations"
][
1
][
"value"
],
entry
[
"conversations"
][
1
][
"value"
],
)
)
lora_request
,
tokenizer
=
self
.
get_random_lora_request
(
lora_request
=
self
.
get_random_lora_request
(
tokenizer
=
tokenizer
,
max_loras
=
max_loras
,
lora_path
=
lora_path
)
max_loras
=
max_loras
,
lora_path
=
lora_path
)
prompt_ids
=
tokenizer
(
prompt
).
input_ids
prompt_ids
=
tokenizer
(
prompt
).
input_ids
completion_ids
=
tokenizer
(
completion
).
input_ids
completion_ids
=
tokenizer
(
completion
).
input_ids
prompt_len
=
len
(
prompt_ids
)
prompt_len
=
len
(
prompt_ids
)
...
@@ -994,11 +982,11 @@ class ShareGPTDataset(BenchmarkDataset):
...
@@ -994,11 +982,11 @@ class ShareGPTDataset(BenchmarkDataset):
skip_min_output_len_check
=
output_len
skip_min_output_len_check
=
output_len
is
not
None
):
is
not
None
):
continue
continue
if
image_path
:
=
entry
.
get
(
"image"
):
if
image_path
:
=
entry
.
get
(
"image"
):
mm_content
=
process_image
(
image_path
)
mm_content
=
process_image
(
image_path
)
elif
video_path
:
=
entry
.
get
(
"video"
):
elif
video_path
:
=
entry
.
get
(
"video"
):
mm_content
=
process_video
(
video_path
)
mm_content
=
process_video
(
video_path
)
else
:
else
:
mm_content
=
None
mm_content
=
None
if
enable_multimodal_chat
:
if
enable_multimodal_chat
:
prompt
=
self
.
apply_multimodal_chat_transformation
(
prompt
=
self
.
apply_multimodal_chat_transformation
(
...
@@ -1013,9 +1001,9 @@ class ShareGPTDataset(BenchmarkDataset):
...
@@ -1013,9 +1001,9 @@ class ShareGPTDataset(BenchmarkDataset):
request_id
=
request_id_prefix
+
str
(
ind
),
request_id
=
request_id_prefix
+
str
(
ind
),
))
))
ind
+=
1
ind
+=
1
self
.
maybe_oversample_requests
(
samples
,
self
.
maybe_oversample_requests
(
samples
,
num_requests
,
num_requests
,
request_id_prefix
,
request_id_prefix
,
no_oversample
)
no_oversample
)
return
samples
return
samples
...
@@ -1024,11 +1012,11 @@ class _ValidateDatasetArgs(argparse.Action):
...
@@ -1024,11 +1012,11 @@ class _ValidateDatasetArgs(argparse.Action):
"""Argparse action to validate dataset name and path compatibility."""
"""Argparse action to validate dataset name and path compatibility."""
def
__call__
(
self
,
parser
,
namespace
,
values
,
option_string
=
None
):
def
__call__
(
self
,
parser
,
namespace
,
values
,
option_string
=
None
):
setattr
(
namespace
,
self
.
dest
,
values
)
setattr
(
namespace
,
self
.
dest
,
values
)
# Get current values of both dataset_name and dataset_path
# Get current values of both dataset_name and dataset_path
dataset_name
=
getattr
(
namespace
,
'dataset_name'
,
'random'
)
dataset_name
=
getattr
(
namespace
,
'dataset_name'
,
'random'
)
dataset_path
=
getattr
(
namespace
,
'dataset_path'
,
None
)
dataset_path
=
getattr
(
namespace
,
'dataset_path'
,
None
)
# Validate the combination
# Validate the combination
if
dataset_name
==
"random"
and
dataset_path
is
not
None
:
if
dataset_name
==
"random"
and
dataset_path
is
not
None
:
parser
.
error
(
parser
.
error
(
...
@@ -1053,7 +1041,7 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
...
@@ -1053,7 +1041,7 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
default
=
"random"
,
default
=
"random"
,
action
=
_ValidateDatasetArgs
,
action
=
_ValidateDatasetArgs
,
choices
=
[
choices
=
[
"sharegpt"
,
"burstgpt"
,
"sonnet"
,
"random"
,
"random-mm"
,
"hf"
,
"sharegpt"
,
"burstgpt"
,
"sonnet"
,
"random"
,
"random-mm"
,
"hf"
,
"custom"
,
"prefix_repetition"
,
"spec_bench"
"custom"
,
"prefix_repetition"
,
"spec_bench"
],
],
help
=
"Name of the dataset to benchmark on."
,
help
=
"Name of the dataset to benchmark on."
,
...
@@ -1502,7 +1490,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
...
@@ -1502,7 +1490,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
# For datasets that follow a similar structure, use a mapping.
# For datasets that follow a similar structure, use a mapping.
dataset_mapping
=
{
dataset_mapping
=
{
"spec_bench"
:
"spec_bench"
:
lambda
:
SpecBench
(
dataset_path
=
args
.
dataset_path
,
lambda
:
SpecBench
(
dataset_path
=
args
.
dataset_path
,
category
=
args
.
spec_bench_category
).
sample
(
category
=
args
.
spec_bench_category
).
sample
(
num_requests
=
args
.
num_prompts
,
num_requests
=
args
.
num_prompts
,
tokenizer
=
tokenizer
,
tokenizer
=
tokenizer
,
...
@@ -1660,7 +1648,7 @@ class CustomDataset(BenchmarkDataset):
...
@@ -1660,7 +1648,7 @@ class CustomDataset(BenchmarkDataset):
logger
.
info
(
"num_requests is set to 0 or negative, "
logger
.
info
(
"num_requests is set to 0 or negative, "
"so using all available samples: %d"
,
"so using all available samples: %d"
,
num_requests
)
num_requests
)
sampled_requests
=
[]
sampled_requests
=
[]
for
i
,
item
in
enumerate
(
self
.
data
):
for
i
,
item
in
enumerate
(
self
.
data
):
if
len
(
sampled_requests
)
>=
num_requests
:
if
len
(
sampled_requests
)
>=
num_requests
:
...
@@ -1686,7 +1674,7 @@ class CustomDataset(BenchmarkDataset):
...
@@ -1686,7 +1674,7 @@ class CustomDataset(BenchmarkDataset):
expected_output_len
=
output_len
,
expected_output_len
=
output_len
,
request_id
=
request_id_prefix
+
str
(
i
),
request_id
=
request_id_prefix
+
str
(
i
),
))
))
self
.
maybe_oversample_requests
(
sampled_requests
,
num_requests
,
self
.
maybe_oversample_requests
(
sampled_requests
,
num_requests
,
request_id_prefix
,
no_oversample
)
request_id_prefix
,
no_oversample
)
return
sampled_requests
return
sampled_requests
...
@@ -1700,7 +1688,7 @@ class CustomDataset(BenchmarkDataset):
...
@@ -1700,7 +1688,7 @@ class CustomDataset(BenchmarkDataset):
class
SpecBench
(
CustomDataset
):
class
SpecBench
(
CustomDataset
):
"""
"""
Implements the SpecBench dataset: https://github.com/hemingkx/Spec-Bench
Implements the SpecBench dataset: https://github.com/hemingkx/Spec-Bench
Download the dataset using:
Download the dataset using:
wget https://raw.githubusercontent.com/hemingkx/Spec-Bench/refs/heads/main/data/spec_bench/question.jsonl
wget https://raw.githubusercontent.com/hemingkx/Spec-Bench/refs/heads/main/data/spec_bench/question.jsonl
"""
# noqa: E501
"""
# noqa: E501
...
@@ -1736,8 +1724,8 @@ class SpecBench(CustomDataset):
...
@@ -1736,8 +1724,8 @@ class SpecBench(CustomDataset):
# leverage CustomDataset sample
# leverage CustomDataset sample
kwargs
[
"skip_chat_template"
]
=
False
kwargs
[
"skip_chat_template"
]
=
False
return
super
().
sample
(
**
kwargs
)
return
super
().
sample
(
**
kwargs
)
# -----------------------------------------------------------------------------
# -----------------------------------------------------------------------------
# Sonnet Dataset Implementation
# Sonnet Dataset Implementation
# -----------------------------------------------------------------------------
# -----------------------------------------------------------------------------
...
@@ -1882,8 +1870,8 @@ class BurstGPTDataset(BenchmarkDataset):
...
@@ -1882,8 +1870,8 @@ class BurstGPTDataset(BenchmarkDataset):
for
i
in
range
(
num_requests
):
for
i
in
range
(
num_requests
):
input_len
=
int
(
data
[
i
][
2
])
input_len
=
int
(
data
[
i
][
2
])
output_len
=
int
(
data
[
i
][
3
])
output_len
=
int
(
data
[
i
][
3
])
lora_req
,
tokenizer
=
self
.
get_random_lora_request
(
lora_req
=
self
.
get_random_lora_request
(
tokenizer
=
tokenizer
,
max_loras
=
max_loras
,
lora_path
=
lora_path
)
max_loras
=
max_loras
,
lora_path
=
lora_path
)
vocab_size
=
tokenizer
.
vocab_size
vocab_size
=
tokenizer
.
vocab_size
# Generate a synthetic prompt: a list of token IDs computed as (i +
# Generate a synthetic prompt: a list of token IDs computed as (i +
# j) modulo vocab_size.
# j) modulo vocab_size.
...
@@ -1995,7 +1983,7 @@ class ConversationDataset(HuggingFaceDataset):
...
@@ -1995,7 +1983,7 @@ class ConversationDataset(HuggingFaceDataset):
request_id
=
request_id_prefix
+
str
(
ind
),
request_id
=
request_id_prefix
+
str
(
ind
),
))
))
ind
+=
1
ind
+=
1
self
.
maybe_oversample_requests
(
sampled_requests
,
num_requests
,
self
.
maybe_oversample_requests
(
sampled_requests
,
num_requests
,
request_id_prefix
,
no_oversample
)
request_id_prefix
,
no_oversample
)
return
sampled_requests
return
sampled_requests
...
@@ -2055,7 +2043,7 @@ class VisionArenaDataset(HuggingFaceDataset):
...
@@ -2055,7 +2043,7 @@ class VisionArenaDataset(HuggingFaceDataset):
multi_modal_data
=
mm_content
,
multi_modal_data
=
mm_content
,
request_id
=
request_id_prefix
+
str
(
i
),
request_id
=
request_id_prefix
+
str
(
i
),
))
))
self
.
maybe_oversample_requests
(
sampled_requests
,
num_requests
,
self
.
maybe_oversample_requests
(
sampled_requests
,
num_requests
,
request_id_prefix
,
no_oversample
)
request_id_prefix
,
no_oversample
)
return
sampled_requests
return
sampled_requests
...
@@ -2172,7 +2160,7 @@ class InstructCoderDataset(HuggingFaceDataset):
...
@@ -2172,7 +2160,7 @@ class InstructCoderDataset(HuggingFaceDataset):
expected_output_len
=
output_len
,
expected_output_len
=
output_len
,
request_id
=
request_id_prefix
+
str
(
i
),
request_id
=
request_id_prefix
+
str
(
i
),
))
))
self
.
maybe_oversample_requests
(
sampled_requests
,
num_requests
,
self
.
maybe_oversample_requests
(
sampled_requests
,
num_requests
,
request_id_prefix
,
no_oversample
)
request_id_prefix
,
no_oversample
)
return
sampled_requests
return
sampled_requests
...
@@ -2234,7 +2222,7 @@ class MTBenchDataset(HuggingFaceDataset):
...
@@ -2234,7 +2222,7 @@ class MTBenchDataset(HuggingFaceDataset):
expected_output_len
=
output_len
,
expected_output_len
=
output_len
,
request_id
=
request_id_prefix
+
str
(
i
),
request_id
=
request_id_prefix
+
str
(
i
),
))
))
self
.
maybe_oversample_requests
(
sampled_requests
,
num_requests
,
self
.
maybe_oversample_requests
(
sampled_requests
,
num_requests
,
request_id_prefix
,
no_oversample
)
request_id_prefix
,
no_oversample
)
return
sampled_requests
return
sampled_requests
...
@@ -2288,8 +2276,8 @@ class BlazeditDataset(HuggingFaceDataset):
...
@@ -2288,8 +2276,8 @@ class BlazeditDataset(HuggingFaceDataset):
# compare the levenshtein distance normalized by code length
# compare the levenshtein distance normalized by code length
if
norm_distance
<
min_distance
or
norm_distance
>
max_distance
:
if
norm_distance
<
min_distance
or
norm_distance
>
max_distance
:
continue
continue
# template copied from
# template copied from
# https://github.com/ise-uiuc/blazedit/blob/7765137e656fd62de877422d2e4cf8de51228054/dataset/create_refined_dataset.py#L94-L105 # noqa: E501
# https://github.com/ise-uiuc/blazedit/blob/7765137e656fd62de877422d2e4cf8de51228054/dataset/create_refined_dataset.py#L94-L105 # noqa: E501
instruction
=
f
"""Given a code file, please apply the change requests and generate the new file.
instruction
=
f
"""Given a code file, please apply the change requests and generate the new file.
...
@@ -2322,9 +2310,9 @@ Please generate the new code file in the "New file" section below.""" # noqa: E5
...
@@ -2322,9 +2310,9 @@ Please generate the new code file in the "New file" section below.""" # noqa: E5
expected_output_len
=
output_len
,
expected_output_len
=
output_len
,
request_id
=
request_id_prefix
+
str
(
i
),
request_id
=
request_id_prefix
+
str
(
i
),
))
))
self
.
maybe_oversample_requests
(
sampled_requests
,
num_requests
,
self
.
maybe_oversample_requests
(
sampled_requests
,
num_requests
,
request_id_prefix
,
no_oversample
)
request_id_prefix
,
no_oversample
)
return
sampled_requests
return
sampled_requests
...
@@ -2376,7 +2364,6 @@ class AIMODataset(HuggingFaceDataset):
...
@@ -2376,7 +2364,6 @@ class AIMODataset(HuggingFaceDataset):
expected_output_len
=
output_len
,
expected_output_len
=
output_len
,
multi_modal_data
=
None
,
multi_modal_data
=
None
,
request_id
=
request_id_prefix
+
str
(
ind
),
request_id
=
request_id_prefix
+
str
(
ind
),
))
))
ind
+=
1
ind
+=
1
self
.
maybe_oversample_requests
(
sampled_requests
,
num_requests
,
self
.
maybe_oversample_requests
(
sampled_requests
,
num_requests
,
...
@@ -2470,9 +2457,9 @@ class NextEditPredictionDataset(HuggingFaceDataset):
...
@@ -2470,9 +2457,9 @@ class NextEditPredictionDataset(HuggingFaceDataset):
))
))
if
len
(
samples
)
>=
num_requests
:
if
len
(
samples
)
>=
num_requests
:
break
break
self
.
maybe_oversample_requests
(
samples
,
self
.
maybe_oversample_requests
(
samples
,
num_requests
,
num_requests
,
request_id_prefix
,
request_id_prefix
,
no_oversample
)
no_oversample
)
return
samples
return
samples
...
@@ -2562,7 +2549,7 @@ class ASRDataset(HuggingFaceDataset):
...
@@ -2562,7 +2549,7 @@ class ASRDataset(HuggingFaceDataset):
" what Whisper supports."
,
" what Whisper supports."
,
skipped
,
skipped
,
)
)
self
.
maybe_oversample_requests
(
sampled_requests
,
num_requests
,
self
.
maybe_oversample_requests
(
sampled_requests
,
num_requests
,
request_id_prefix
,
no_oversample
)
request_id_prefix
,
no_oversample
)
return
sampled_requests
return
sampled_requests
...
@@ -2647,7 +2634,7 @@ class MLPerfDataset(HuggingFaceDataset):
...
@@ -2647,7 +2634,7 @@ class MLPerfDataset(HuggingFaceDataset):
)
)
ind
+=
1
ind
+=
1
self
.
maybe_oversample_requests
(
sampled_requests
,
num_requests
,
self
.
maybe_oversample_requests
(
sampled_requests
,
num_requests
,
request_id_prefix
,
no_oversample
)
request_id_prefix
,
no_oversample
)
return
sampled_requests
return
sampled_requests
...
@@ -2658,7 +2645,7 @@ class MLPerfDataset(HuggingFaceDataset):
...
@@ -2658,7 +2645,7 @@ class MLPerfDataset(HuggingFaceDataset):
class
PrefixRepetitionRandomDataset
(
BenchmarkDataset
):
class
PrefixRepetitionRandomDataset
(
BenchmarkDataset
):
# Default values copied from benchmark_serving.py for the repeated prefix
# Default values copied from benchmark_serving.py for the repeated prefix
# dataset.
# dataset.
DEFAULT_PREFIX_LEN
=
256
DEFAULT_PREFIX_LEN
=
256
DEFAULT_SUFFIX_LEN
=
256
DEFAULT_SUFFIX_LEN
=
256
...
...
vllm/engine/async_llm_engine.py
View file @
6c47f6bf
...
@@ -390,11 +390,8 @@ class _AsyncLLMEngine(LLMEngine):
...
@@ -390,11 +390,8 @@ class _AsyncLLMEngine(LLMEngine):
"""Stop the remote worker execution loop."""
"""Stop the remote worker execution loop."""
await
self
.
model_executor
.
stop_remote_worker_execution_loop_async
()
await
self
.
model_executor
.
stop_remote_worker_execution_loop_async
()
async
def
get_tokenizer_async
(
self
,
async
def
get_tokenizer_async
(
self
)
->
AnyTokenizer
:
lora_request
:
Optional
[
LoRARequest
]
=
None
return
self
.
get_tokenizer
()
)
->
AnyTokenizer
:
return
await
(
self
.
get_tokenizer_group
().
get_lora_tokenizer_async
(
lora_request
))
async
def
add_request_async
(
async
def
add_request_async
(
self
,
self
,
...
@@ -435,7 +432,6 @@ class _AsyncLLMEngine(LLMEngine):
...
@@ -435,7 +432,6 @@ class _AsyncLLMEngine(LLMEngine):
processed_inputs
=
await
self
.
input_preprocessor
.
preprocess_async
(
processed_inputs
=
await
self
.
input_preprocessor
.
preprocess_async
(
prompt
,
prompt
,
lora_request
=
lora_request
,
tokenization_kwargs
=
tokenization_kwargs
,
tokenization_kwargs
=
tokenization_kwargs
,
)
)
...
@@ -614,11 +610,8 @@ class AsyncLLMEngine(EngineClient):
...
@@ -614,11 +610,8 @@ class AsyncLLMEngine(EngineClient):
async
def
get_input_preprocessor
(
self
)
->
InputPreprocessor
:
async
def
get_input_preprocessor
(
self
)
->
InputPreprocessor
:
return
self
.
engine
.
input_preprocessor
return
self
.
engine
.
input_preprocessor
async
def
get_tokenizer
(
async
def
get_tokenizer
(
self
)
->
AnyTokenizer
:
self
,
return
self
.
engine
.
get_tokenizer
()
lora_request
:
Optional
[
LoRARequest
]
=
None
,
)
->
AnyTokenizer
:
return
await
self
.
engine
.
get_tokenizer_async
(
lora_request
)
def
start_background_loop
(
self
)
->
None
:
def
start_background_loop
(
self
)
->
None
:
"""Start the background loop."""
"""Start the background loop."""
...
...
vllm/engine/llm_engine.py
View file @
6c47f6bf
...
@@ -49,9 +49,8 @@ from vllm.sequence import (ExecuteModelRequest, ParallelSampleSequenceGroup,
...
@@ -49,9 +49,8 @@ from vllm.sequence import (ExecuteModelRequest, ParallelSampleSequenceGroup,
from
vllm.tracing
import
(
SpanAttributes
,
SpanKind
,
extract_trace_context
,
from
vllm.tracing
import
(
SpanAttributes
,
SpanKind
,
extract_trace_context
,
init_tracer
)
init_tracer
)
from
vllm.transformers_utils.detokenizer
import
Detokenizer
from
vllm.transformers_utils.detokenizer
import
Detokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.transformers_utils.tokenizer
import
(
AnyTokenizer
,
from
vllm.transformers_utils.tokenizer_group
import
(
init_tokenizer_from_configs
)
TokenizerGroup
,
init_tokenizer_from_configs
)
from
vllm.usage.usage_lib
import
(
UsageContext
,
is_usage_stats_enabled
,
from
vllm.usage.usage_lib
import
(
UsageContext
,
is_usage_stats_enabled
,
usage_message
)
usage_message
)
from
vllm.utils
import
Counter
,
Device
,
resolve_obj_by_qualname
,
weak_bind
from
vllm.utils
import
Counter
,
Device
,
resolve_obj_by_qualname
,
weak_bind
...
@@ -186,7 +185,7 @@ class LLMEngine:
...
@@ -186,7 +185,7 @@ class LLMEngine:
return
outputs_
return
outputs_
tokenizer
:
Optional
[
Tokenizer
Group
]
tokenizer
:
Optional
[
Any
Tokenizer
]
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -233,18 +232,9 @@ class LLMEngine:
...
@@ -233,18 +232,9 @@ class LLMEngine:
if
self
.
model_config
.
skip_tokenizer_init
:
if
self
.
model_config
.
skip_tokenizer_init
:
self
.
tokenizer
=
None
self
.
tokenizer
=
None
self
.
detokenizer
=
None
self
.
detokenizer
=
None
tokenizer_group
=
None
else
:
else
:
self
.
tokenizer
=
self
.
_init_tokenizer
()
self
.
tokenizer
=
self
.
_init_tokenizer
()
self
.
detokenizer
=
Detokenizer
(
self
.
tokenizer
)
self
.
detokenizer
=
Detokenizer
(
self
.
tokenizer
)
tokenizer_group
=
self
.
get_tokenizer_group
()
# Ensure that the function doesn't contain a reference to self,
# to avoid engine GC issues
def
get_tokenizer_for_seq
(
sequence
:
Sequence
)
->
AnyTokenizer
:
assert
tokenizer_group
,
(
"tokenizer_group cannot be None, "
"make sure skip_tokenizer_init is False"
)
return
tokenizer_group
.
get_lora_tokenizer
(
sequence
.
lora_request
)
self
.
seq_counter
=
Counter
()
self
.
seq_counter
=
Counter
()
self
.
generation_config_fields
=
(
self
.
generation_config_fields
=
(
...
@@ -389,10 +379,8 @@ class LLMEngine:
...
@@ -389,10 +379,8 @@ class LLMEngine:
self
.
detokenizer
,
self
.
detokenizer
,
self
.
scheduler
,
self
.
scheduler
,
self
.
seq_counter
,
self
.
seq_counter
,
get_tokenizer_for_seq
,
stop_checker
=
StopChecker
(
stop_checker
=
StopChecker
(
self
.
scheduler_config
.
max_model_len
,
self
.
scheduler_config
.
max_model_len
,
get_tokenizer_for_seq
,
self
.
reasoner
if
self
.
decoding_config
.
reasoning_backend
self
.
reasoner
if
self
.
decoding_config
.
reasoning_backend
and
self
.
tokenizer
else
None
,
and
self
.
tokenizer
else
None
,
),
),
...
@@ -521,24 +509,15 @@ class LLMEngine:
...
@@ -521,24 +509,15 @@ class LLMEngine:
if
model_executor
:
=
getattr
(
self
,
"model_executor"
,
None
):
if
model_executor
:
=
getattr
(
self
,
"model_executor"
,
None
):
model_executor
.
shutdown
()
model_executor
.
shutdown
()
def
get_tokenizer
_group
(
self
)
->
Tokenizer
Group
:
def
get_tokenizer
(
self
)
->
Any
Tokenizer
:
if
self
.
tokenizer
is
None
:
if
self
.
tokenizer
is
None
:
raise
ValueError
(
"Unable to get tokenizer because "
raise
ValueError
(
"Unable to get tokenizer because "
"skip_tokenizer_init is True"
)
"skip_tokenizer_init is True"
)
return
self
.
tokenizer
return
self
.
tokenizer
def
get_tokenizer
(
def
_init_tokenizer
(
self
)
->
AnyTokenizer
:
self
,
return
init_tokenizer_from_configs
(
model_config
=
self
.
model_config
)
lora_request
:
Optional
[
LoRARequest
]
=
None
,
)
->
AnyTokenizer
:
return
self
.
get_tokenizer_group
().
get_lora_tokenizer
(
lora_request
)
def
_init_tokenizer
(
self
)
->
TokenizerGroup
:
return
init_tokenizer_from_configs
(
model_config
=
self
.
model_config
,
scheduler_config
=
self
.
scheduler_config
,
lora_config
=
self
.
lora_config
)
def
_verify_args
(
self
)
->
None
:
def
_verify_args
(
self
)
->
None
:
self
.
model_config
.
verify_with_parallel_config
(
self
.
parallel_config
)
self
.
model_config
.
verify_with_parallel_config
(
self
.
parallel_config
)
...
@@ -574,11 +553,11 @@ class LLMEngine:
...
@@ -574,11 +553,11 @@ class LLMEngine:
)
)
return
None
return
None
self
.
_validate_model_inputs
(
processed_inputs
,
lora_request
)
self
.
_validate_model_inputs
(
processed_inputs
)
# Create the sequences.
# Create the sequences.
block_size
=
self
.
cache_config
.
block_size
block_size
=
self
.
cache_config
.
block_size
seq_id
=
next
(
self
.
seq_counter
)
seq_id
=
next
(
self
.
seq_counter
)
eos_token_id
=
self
.
input_preprocessor
.
get_eos_token_id
(
lora_request
)
eos_token_id
=
self
.
input_preprocessor
.
get_eos_token_id
()
encoder_inputs
,
decoder_inputs
=
split_enc_dec_inputs
(
processed_inputs
)
encoder_inputs
,
decoder_inputs
=
split_enc_dec_inputs
(
processed_inputs
)
...
@@ -700,7 +679,6 @@ class LLMEngine:
...
@@ -700,7 +679,6 @@ class LLMEngine:
processed_inputs
=
self
.
input_preprocessor
.
preprocess
(
processed_inputs
=
self
.
input_preprocessor
.
preprocess
(
prompt
,
prompt
,
tokenization_kwargs
=
tokenization_kwargs
,
tokenization_kwargs
=
tokenization_kwargs
,
lora_request
=
lora_request
,
)
)
self
.
_add_processed_request
(
self
.
_add_processed_request
(
...
@@ -1739,29 +1717,22 @@ class LLMEngine:
...
@@ -1739,29 +1717,22 @@ class LLMEngine:
SpanAttributes
.
GEN_AI_LATENCY_TIME_IN_MODEL_EXECUTE
,
SpanAttributes
.
GEN_AI_LATENCY_TIME_IN_MODEL_EXECUTE
,
metrics
.
model_execute_time
)
metrics
.
model_execute_time
)
def
_validate_model_inputs
(
self
,
inputs
:
ProcessorInputs
,
def
_validate_model_inputs
(
self
,
inputs
:
ProcessorInputs
):
lora_request
:
Optional
[
LoRARequest
]):
encoder_inputs
,
decoder_inputs
=
split_enc_dec_inputs
(
inputs
)
encoder_inputs
,
decoder_inputs
=
split_enc_dec_inputs
(
inputs
)
if
encoder_inputs
is
not
None
:
if
encoder_inputs
is
not
None
:
self
.
_validate_model_input
(
encoder_inputs
,
self
.
_validate_model_input
(
encoder_inputs
,
prompt_type
=
"encoder"
)
lora_request
,
prompt_type
=
"encoder"
)
self
.
_validate_model_input
(
decoder_inputs
,
self
.
_validate_model_input
(
decoder_inputs
,
prompt_type
=
"decoder"
)
lora_request
,
prompt_type
=
"decoder"
)
def
_validate_model_input
(
def
_validate_model_input
(
self
,
self
,
prompt_inputs
:
SingletonInputs
,
prompt_inputs
:
SingletonInputs
,
lora_request
:
Optional
[
LoRARequest
],
*
,
*
,
prompt_type
:
Literal
[
"encoder"
,
"decoder"
],
prompt_type
:
Literal
[
"encoder"
,
"decoder"
],
):
):
model_config
=
self
.
model_config
model_config
=
self
.
model_config
tokenizer
=
(
None
if
self
.
tokenizer
is
None
else
tokenizer
=
self
.
tokenizer
self
.
tokenizer
.
get_lora_tokenizer
(
lora_request
))
prompt_ids
=
prompt_inputs
.
get
(
"prompt_token_ids"
,
[])
prompt_ids
=
prompt_inputs
.
get
(
"prompt_token_ids"
,
[])
if
not
prompt_ids
:
if
not
prompt_ids
:
...
@@ -1822,7 +1793,7 @@ class LLMEngine:
...
@@ -1822,7 +1793,7 @@ class LLMEngine:
logits_processors
=
[]
logits_processors
=
[]
if
(
sampling_params
.
logit_bias
or
sampling_params
.
allowed_token_ids
):
if
(
sampling_params
.
logit_bias
or
sampling_params
.
allowed_token_ids
):
tokenizer
=
self
.
get_tokenizer
(
lora_request
=
lora_request
)
tokenizer
=
self
.
get_tokenizer
()
processors
=
get_openai_logits_processors
(
processors
=
get_openai_logits_processors
(
logit_bias
=
sampling_params
.
logit_bias
,
logit_bias
=
sampling_params
.
logit_bias
,
...
@@ -1835,7 +1806,7 @@ class LLMEngine:
...
@@ -1835,7 +1806,7 @@ class LLMEngine:
sampling_params
.
allowed_token_ids
=
None
sampling_params
.
allowed_token_ids
=
None
if
len
(
sampling_params
.
bad_words
)
>
0
:
if
len
(
sampling_params
.
bad_words
)
>
0
:
tokenizer
=
self
.
get_tokenizer
(
lora_request
)
tokenizer
=
self
.
get_tokenizer
()
processors
=
get_bad_words_logits_processors
(
processors
=
get_bad_words_logits_processors
(
bad_words
=
sampling_params
.
bad_words
,
tokenizer
=
tokenizer
)
bad_words
=
sampling_params
.
bad_words
,
tokenizer
=
tokenizer
)
logits_processors
.
extend
(
processors
)
logits_processors
.
extend
(
processors
)
...
...
vllm/engine/output_processor/interfaces.py
View file @
6c47f6bf
...
@@ -2,14 +2,13 @@
...
@@ -2,14 +2,13 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
typing
import
Callable
,
List
from
typing
import
List
from
vllm.config
import
SchedulerConfig
from
vllm.config
import
SchedulerConfig
from
vllm.core.scheduler
import
Scheduler
from
vllm.core.scheduler
import
Scheduler
from
vllm.engine.output_processor.stop_checker
import
StopChecker
from
vllm.engine.output_processor.stop_checker
import
StopChecker
from
vllm.sequence
import
Sequence
,
SequenceGroup
,
SequenceGroupOutput
from
vllm.sequence
import
SequenceGroup
,
SequenceGroupOutput
from
vllm.transformers_utils.detokenizer
import
Detokenizer
from
vllm.transformers_utils.detokenizer
import
Detokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.utils
import
Counter
from
vllm.utils
import
Counter
...
@@ -31,7 +30,6 @@ class SequenceGroupOutputProcessor(ABC):
...
@@ -31,7 +30,6 @@ class SequenceGroupOutputProcessor(ABC):
detokenizer
:
Detokenizer
,
detokenizer
:
Detokenizer
,
scheduler
:
List
[
Scheduler
],
scheduler
:
List
[
Scheduler
],
seq_counter
:
Counter
,
seq_counter
:
Counter
,
get_tokenizer_for_seq
:
Callable
[[
Sequence
],
AnyTokenizer
],
stop_checker
:
"StopChecker"
,
stop_checker
:
"StopChecker"
,
):
):
"""Create an output processor.
"""Create an output processor.
...
...
vllm/engine/output_processor/stop_checker.py
View file @
6c47f6bf
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Callable
,
List
,
Optional
,
Tuple
from
typing
import
List
,
Optional
,
Tuple
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.reasoning
import
ReasoningParser
from
vllm.reasoning
import
ReasoningParser
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
Sequence
,
SequenceStatus
from
vllm.sequence
import
Sequence
,
SequenceStatus
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
class
StopChecker
:
class
StopChecker
:
...
@@ -20,12 +19,10 @@ class StopChecker:
...
@@ -20,12 +19,10 @@ class StopChecker:
def
__init__
(
def
__init__
(
self
,
self
,
max_model_len
:
int
,
max_model_len
:
int
,
get_tokenizer_for_seq
:
Callable
[[
Sequence
],
AnyTokenizer
],
reasoner
:
Optional
[
ReasoningParser
]
=
None
,
reasoner
:
Optional
[
ReasoningParser
]
=
None
,
):
):
# Do not use it directly, but use `self._get_max_model_len`.
# Do not use it directly, but use `self._get_max_model_len`.
self
.
_max_model_len
=
max_model_len
self
.
_max_model_len
=
max_model_len
self
.
get_tokenizer_for_seq
=
get_tokenizer_for_seq
self
.
reasoner
=
reasoner
self
.
reasoner
=
reasoner
def
_get_max_model_len
(
self
,
lora_req
:
Optional
[
LoRARequest
]):
def
_get_max_model_len
(
self
,
lora_req
:
Optional
[
LoRARequest
]):
...
...
vllm/engine/protocol.py
View file @
6c47f6bf
...
@@ -76,8 +76,7 @@ class EngineClient(ABC):
...
@@ -76,8 +76,7 @@ class EngineClient(ABC):
include_stop_str_in_output
=
params
.
include_stop_str_in_output
include_stop_str_in_output
=
params
.
include_stop_str_in_output
preprocessor
=
await
self
.
get_input_preprocessor
()
preprocessor
=
await
self
.
get_input_preprocessor
()
tokenizer_group
=
preprocessor
.
get_tokenizer_group
()
tokenizer
=
preprocessor
.
get_tokenizer
()
tokenizer
=
await
tokenizer_group
.
get_lora_tokenizer_async
()
eos_token_id
=
tokenizer
.
eos_token_id
eos_token_id
=
tokenizer
.
eos_token_id
if
is_explicit_encoder_decoder_prompt
(
prompt
):
if
is_explicit_encoder_decoder_prompt
(
prompt
):
...
@@ -260,11 +259,8 @@ class EngineClient(ABC):
...
@@ -260,11 +259,8 @@ class EngineClient(ABC):
...
...
@
abstractmethod
@
abstractmethod
async
def
get_tokenizer
(
async
def
get_tokenizer
(
self
)
->
AnyTokenizer
:
self
,
"""Get the tokenizer"""
lora_request
:
Optional
[
LoRARequest
]
=
None
,
)
->
AnyTokenizer
:
"""Get the appropriate tokenizer for the request"""
...
...
async
def
get_io_processor
(
self
)
->
IOProcessor
:
async
def
get_io_processor
(
self
)
->
IOProcessor
:
...
...
vllm/entrypoints/llm.py
View file @
6c47f6bf
...
@@ -301,23 +301,17 @@ class LLM:
...
@@ -301,23 +301,17 @@ class LLM:
self
.
io_processor
=
get_io_processor
(
self
.
llm_engine
.
vllm_config
,
self
.
io_processor
=
get_io_processor
(
self
.
llm_engine
.
vllm_config
,
io_processor_plugin
)
io_processor_plugin
)
def
get_tokenizer
(
def
get_tokenizer
(
self
)
->
AnyTokenizer
:
self
,
return
self
.
llm_engine
.
get_tokenizer
()
lora_request
:
Optional
[
LoRARequest
]
=
None
,
)
->
AnyTokenizer
:
return
self
.
llm_engine
.
get_tokenizer_group
().
get_lora_tokenizer
(
lora_request
)
def
set_tokenizer
(
self
,
tokenizer
:
AnyTokenizer
)
->
None
:
def
set_tokenizer
(
self
,
tokenizer
:
AnyTokenizer
)
->
None
:
tokenizer_group
=
self
.
llm_engine
.
get_tokenizer_group
()
# While CachedTokenizer is dynamic, have no choice but
# While CachedTokenizer is dynamic, have no choice but
# compare class name. Misjudgment will arise from
# compare class name. Misjudgment will arise from
# user-defined tokenizer started with 'Cached'
# user-defined tokenizer started with 'Cached'
if
tokenizer
.
__class__
.
__name__
.
startswith
(
"Cached"
):
if
tokenizer
.
__class__
.
__name__
.
startswith
(
"Cached"
):
tokenizer_group
.
tokenizer
=
tokenizer
self
.
llm_engine
.
tokenizer
=
tokenizer
else
:
else
:
tokenizer_group
.
tokenizer
=
get_cached_tokenizer
(
tokenizer
)
self
.
llm_engine
.
tokenizer
=
get_cached_tokenizer
(
tokenizer
)
def
get_default_sampling_params
(
self
)
->
SamplingParams
:
def
get_default_sampling_params
(
self
)
->
SamplingParams
:
if
self
.
default_sampling_params
is
None
:
if
self
.
default_sampling_params
is
None
:
...
@@ -707,7 +701,6 @@ class LLM:
...
@@ -707,7 +701,6 @@ class LLM:
self
,
self
,
messages
:
Union
[
list
[
ChatCompletionMessageParam
],
messages
:
Union
[
list
[
ChatCompletionMessageParam
],
list
[
list
[
ChatCompletionMessageParam
]]],
list
[
list
[
ChatCompletionMessageParam
]]],
lora_request
:
Optional
[
LoRARequest
]
=
None
,
chat_template
:
Optional
[
str
]
=
None
,
chat_template
:
Optional
[
str
]
=
None
,
chat_template_content_format
:
ChatTemplateContentFormatOption
=
"auto"
,
chat_template_content_format
:
ChatTemplateContentFormatOption
=
"auto"
,
add_generation_prompt
:
bool
=
True
,
add_generation_prompt
:
bool
=
True
,
...
@@ -739,7 +732,7 @@ class LLM:
...
@@ -739,7 +732,7 @@ class LLM:
cast
(
list
[
ChatCompletionMessageParam
],
messages
)
cast
(
list
[
ChatCompletionMessageParam
],
messages
)
]
]
tokenizer
=
self
.
get_tokenizer
(
lora_request
)
tokenizer
=
self
.
get_tokenizer
()
model_config
=
self
.
llm_engine
.
get_model_config
()
model_config
=
self
.
llm_engine
.
get_model_config
()
resolved_content_format
=
resolve_chat_template_content_format
(
resolved_content_format
=
resolve_chat_template_content_format
(
chat_template
,
chat_template
,
...
@@ -872,7 +865,6 @@ class LLM:
...
@@ -872,7 +865,6 @@ class LLM:
prompts
=
self
.
preprocess_chat
(
prompts
=
self
.
preprocess_chat
(
messages
=
messages
,
messages
=
messages
,
lora_request
=
lora_request
,
chat_template
=
chat_template
,
chat_template
=
chat_template
,
chat_template_content_format
=
chat_template_content_format
,
chat_template_content_format
=
chat_template_content_format
,
add_generation_prompt
=
add_generation_prompt
,
add_generation_prompt
=
add_generation_prompt
,
...
@@ -1519,7 +1511,7 @@ class LLM:
...
@@ -1519,7 +1511,7 @@ class LLM:
):
):
"""
"""
Validate that if any multi-modal data is skipped (i.e. None),
Validate that if any multi-modal data is skipped (i.e. None),
then its corresponding UUID must be set.
then its corresponding UUID must be set.
"""
"""
if
multi_modal_data
is
None
:
if
multi_modal_data
is
None
:
return
return
...
...
vllm/entrypoints/openai/serving_chat.py
View file @
6c47f6bf
...
@@ -188,7 +188,7 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -188,7 +188,7 @@ class OpenAIServingChat(OpenAIServing):
model_name
=
self
.
models
.
model_name
(
lora_request
)
model_name
=
self
.
models
.
model_name
(
lora_request
)
tokenizer
=
await
self
.
engine_client
.
get_tokenizer
(
lora_request
)
tokenizer
=
await
self
.
engine_client
.
get_tokenizer
()
tool_parser
=
self
.
tool_parser
tool_parser
=
self
.
tool_parser
...
...
vllm/entrypoints/openai/serving_classification.py
View file @
6c47f6bf
...
@@ -50,10 +50,7 @@ class ClassificationMixin(OpenAIServing):
...
@@ -50,10 +50,7 @@ class ClassificationMixin(OpenAIServing):
return
None
return
None
try
:
try
:
ctx
.
lora_request
=
self
.
_maybe_get_adapters
(
ctx
.
request
)
ctx
.
tokenizer
=
await
self
.
engine_client
.
get_tokenizer
()
ctx
.
tokenizer
=
await
self
.
engine_client
.
get_tokenizer
(
ctx
.
lora_request
)
renderer
=
self
.
_get_renderer
(
ctx
.
tokenizer
)
renderer
=
self
.
_get_renderer
(
ctx
.
tokenizer
)
ctx
.
engine_prompts
=
await
renderer
.
render_prompt
(
ctx
.
engine_prompts
=
await
renderer
.
render_prompt
(
...
...
vllm/entrypoints/openai/serving_completion.py
View file @
6c47f6bf
...
@@ -127,8 +127,7 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -127,8 +127,7 @@ class OpenAIServingCompletion(OpenAIServing):
if
self
.
model_config
.
skip_tokenizer_init
:
if
self
.
model_config
.
skip_tokenizer_init
:
tokenizer
=
None
tokenizer
=
None
else
:
else
:
tokenizer
=
await
self
.
engine_client
.
get_tokenizer
(
lora_request
tokenizer
=
await
self
.
engine_client
.
get_tokenizer
()
)
renderer
=
self
.
_get_renderer
(
tokenizer
)
renderer
=
self
.
_get_renderer
(
tokenizer
)
engine_prompts
=
await
renderer
.
render_prompt_and_embeds
(
engine_prompts
=
await
renderer
.
render_prompt_and_embeds
(
...
...
vllm/entrypoints/openai/serving_embedding.py
View file @
6c47f6bf
...
@@ -76,8 +76,7 @@ class EmbeddingMixin(OpenAIServing):
...
@@ -76,8 +76,7 @@ class EmbeddingMixin(OpenAIServing):
try
:
try
:
ctx
.
lora_request
=
self
.
_maybe_get_adapters
(
ctx
.
request
)
ctx
.
lora_request
=
self
.
_maybe_get_adapters
(
ctx
.
request
)
tokenizer
=
await
self
.
engine_client
.
get_tokenizer
(
ctx
.
lora_request
tokenizer
=
await
self
.
engine_client
.
get_tokenizer
()
)
renderer
=
self
.
_get_renderer
(
tokenizer
)
renderer
=
self
.
_get_renderer
(
tokenizer
)
if
isinstance
(
ctx
.
request
,
EmbeddingChatRequest
):
if
isinstance
(
ctx
.
request
,
EmbeddingChatRequest
):
...
@@ -394,8 +393,8 @@ class EmbeddingMixin(OpenAIServing):
...
@@ -394,8 +393,8 @@ class EmbeddingMixin(OpenAIServing):
)
->
Optional
[
ErrorResponse
]:
)
->
Optional
[
ErrorResponse
]:
"""Collect and aggregate batch results
"""Collect and aggregate batch results
with support for chunked processing.
with support for chunked processing.
For chunked requests, performs online aggregation to
For chunked requests, performs online aggregation to
minimize memory usage.
minimize memory usage.
For regular requests, collects results normally.
For regular requests, collects results normally.
"""
"""
...
...
vllm/entrypoints/openai/serving_pooling.py
View file @
6c47f6bf
...
@@ -103,8 +103,7 @@ class OpenAIServingPooling(OpenAIServing):
...
@@ -103,8 +103,7 @@ class OpenAIServingPooling(OpenAIServing):
if
self
.
model_config
.
skip_tokenizer_init
:
if
self
.
model_config
.
skip_tokenizer_init
:
tokenizer
=
None
tokenizer
=
None
else
:
else
:
tokenizer
=
await
self
.
engine_client
.
get_tokenizer
(
lora_request
tokenizer
=
await
self
.
engine_client
.
get_tokenizer
()
)
renderer
=
self
.
_get_renderer
(
tokenizer
)
renderer
=
self
.
_get_renderer
(
tokenizer
)
if
getattr
(
request
,
"dimensions"
,
None
)
is
not
None
:
if
getattr
(
request
,
"dimensions"
,
None
)
is
not
None
:
...
...
vllm/entrypoints/openai/serving_responses.py
View file @
6c47f6bf
...
@@ -240,7 +240,7 @@ class OpenAIServingResponses(OpenAIServing):
...
@@ -240,7 +240,7 @@ class OpenAIServingResponses(OpenAIServing):
try
:
try
:
lora_request
=
self
.
_maybe_get_adapters
(
request
)
lora_request
=
self
.
_maybe_get_adapters
(
request
)
model_name
=
self
.
models
.
model_name
(
lora_request
)
model_name
=
self
.
models
.
model_name
(
lora_request
)
tokenizer
=
await
self
.
engine_client
.
get_tokenizer
(
lora_request
)
tokenizer
=
await
self
.
engine_client
.
get_tokenizer
()
if
self
.
use_harmony
:
if
self
.
use_harmony
:
messages
,
request_prompts
,
engine_prompts
=
(
messages
,
request_prompts
,
engine_prompts
=
(
...
...
vllm/entrypoints/openai/serving_score.py
View file @
6c47f6bf
...
@@ -269,7 +269,7 @@ class ServingScores(OpenAIServing):
...
@@ -269,7 +269,7 @@ class ServingScores(OpenAIServing):
)
->
Union
[
list
[
PoolingRequestOutput
],
ErrorResponse
]:
)
->
Union
[
list
[
PoolingRequestOutput
],
ErrorResponse
]:
lora_request
=
self
.
_maybe_get_adapters
(
request
)
lora_request
=
self
.
_maybe_get_adapters
(
request
)
tokenizer
=
await
self
.
engine_client
.
get_tokenizer
(
lora_request
)
tokenizer
=
await
self
.
engine_client
.
get_tokenizer
()
truncate_prompt_tokens
=
getattr
(
request
,
"truncate_prompt_tokens"
,
truncate_prompt_tokens
=
getattr
(
request
,
"truncate_prompt_tokens"
,
None
)
None
)
...
...
vllm/entrypoints/openai/serving_tokenization.py
View file @
6c47f6bf
...
@@ -65,7 +65,7 @@ class OpenAIServingTokenization(OpenAIServing):
...
@@ -65,7 +65,7 @@ class OpenAIServingTokenization(OpenAIServing):
try
:
try
:
lora_request
=
self
.
_maybe_get_adapters
(
request
)
lora_request
=
self
.
_maybe_get_adapters
(
request
)
tokenizer
=
await
self
.
engine_client
.
get_tokenizer
(
lora_request
)
tokenizer
=
await
self
.
engine_client
.
get_tokenizer
()
renderer
=
self
.
_get_renderer
(
tokenizer
)
renderer
=
self
.
_get_renderer
(
tokenizer
)
if
isinstance
(
request
,
TokenizeChatRequest
):
if
isinstance
(
request
,
TokenizeChatRequest
):
...
@@ -130,7 +130,7 @@ class OpenAIServingTokenization(OpenAIServing):
...
@@ -130,7 +130,7 @@ class OpenAIServingTokenization(OpenAIServing):
lora_request
=
self
.
_maybe_get_adapters
(
request
)
lora_request
=
self
.
_maybe_get_adapters
(
request
)
tokenizer
=
await
self
.
engine_client
.
get_tokenizer
(
lora_request
)
tokenizer
=
await
self
.
engine_client
.
get_tokenizer
()
self
.
_log_inputs
(
request_id
,
self
.
_log_inputs
(
request_id
,
request
.
tokens
,
request
.
tokens
,
...
...
vllm/inputs/preprocess.py
View file @
6c47f6bf
...
@@ -9,13 +9,11 @@ from typing_extensions import assert_never
...
@@ -9,13 +9,11 @@ from typing_extensions import assert_never
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalRegistry
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalRegistry
from
vllm.multimodal.cache
import
BaseMultiModalProcessorCache
from
vllm.multimodal.cache
import
BaseMultiModalProcessorCache
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalEncDecInputs
,
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalEncDecInputs
,
MultiModalInputs
,
MultiModalUUIDDict
)
MultiModalInputs
,
MultiModalUUIDDict
)
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.transformers_utils.tokenizer_group
import
TokenizerGroup
from
.data
import
(
DecoderOnlyInputs
,
EmbedsInputs
,
EmbedsPrompt
,
from
.data
import
(
DecoderOnlyInputs
,
EmbedsInputs
,
EmbedsPrompt
,
EncoderDecoderInputs
,
ProcessorInputs
,
PromptType
,
EncoderDecoderInputs
,
ProcessorInputs
,
PromptType
,
...
@@ -31,7 +29,7 @@ class InputPreprocessor:
...
@@ -31,7 +29,7 @@ class InputPreprocessor:
def
__init__
(
def
__init__
(
self
,
self
,
model_config
:
ModelConfig
,
model_config
:
ModelConfig
,
tokenizer
:
Optional
[
Tokenizer
Group
],
tokenizer
:
Optional
[
Any
Tokenizer
],
mm_registry
:
MultiModalRegistry
=
MULTIMODAL_REGISTRY
,
mm_registry
:
MultiModalRegistry
=
MULTIMODAL_REGISTRY
,
mm_processor_cache
:
Optional
[
BaseMultiModalProcessorCache
]
=
None
,
mm_processor_cache
:
Optional
[
BaseMultiModalProcessorCache
]
=
None
,
)
->
None
:
)
->
None
:
...
@@ -42,32 +40,28 @@ class InputPreprocessor:
...
@@ -42,32 +40,28 @@ class InputPreprocessor:
self
.
mm_registry
=
mm_registry
self
.
mm_registry
=
mm_registry
self
.
mm_processor_cache
=
mm_processor_cache
self
.
mm_processor_cache
=
mm_processor_cache
def
get_tokenizer
_group
(
self
)
->
Tokenizer
Group
:
def
get_tokenizer
(
self
)
->
Any
Tokenizer
:
if
self
.
tokenizer
is
None
:
if
self
.
tokenizer
is
None
:
raise
ValueError
(
"You cannot pass text prompts when "
raise
ValueError
(
"You cannot pass text prompts when "
"`skip_tokenizer_init` is True"
)
"`skip_tokenizer_init` is True"
)
return
self
.
tokenizer
return
self
.
tokenizer
def
get_bos_token_id
(
self
,
def
get_bos_token_id
(
self
)
->
Optional
[
int
]:
lora_request
:
Optional
[
LoRARequest
]
=
None
)
->
Optional
[
int
]:
if
self
.
tokenizer
is
None
:
if
self
.
tokenizer
is
None
:
logger
.
warning
(
"Using None for BOS token id because tokenizer "
logger
.
warning
(
"Using None for BOS token id because tokenizer "
"is not initialized"
)
"is not initialized"
)
return
None
return
None
return
self
.
tokenizer
.
get_lora_tokenizer
(
lora_request
).
bos_token_id
return
self
.
tokenizer
.
bos_token_id
def
get_eos_token_id
(
self
,
def
get_eos_token_id
(
self
)
->
Optional
[
int
]:
lora_request
:
Optional
[
LoRARequest
]
=
None
)
->
Optional
[
int
]:
if
self
.
tokenizer
is
None
:
if
self
.
tokenizer
is
None
:
logger
.
warning
(
"Using None for EOS token id because tokenizer "
logger
.
warning
(
"Using None for EOS token id because tokenizer "
"is not initialized"
)
"is not initialized"
)
return
None
return
None
return
self
.
tokenizer
.
get_lora_tokenizer
(
lora_request
).
eos_token_id
return
self
.
tokenizer
.
eos_token_id
def
get_decoder_start_token_id
(
self
)
->
Optional
[
int
]:
def
get_decoder_start_token_id
(
self
)
->
Optional
[
int
]:
"""
"""
...
@@ -190,14 +184,13 @@ class InputPreprocessor:
...
@@ -190,14 +184,13 @@ class InputPreprocessor:
def
_tokenize_prompt
(
def
_tokenize_prompt
(
self
,
self
,
prompt
:
str
,
prompt
:
str
,
lora_request
:
Optional
[
LoRARequest
],
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
)
->
list
[
int
]:
)
->
list
[
int
]:
"""
"""
Apply the model's tokenizer to a text prompt, returning the
Apply the model's tokenizer to a text prompt, returning the
corresponding token IDs.
corresponding token IDs.
"""
"""
tokenizer
=
self
.
get_tokenizer
_group
()
tokenizer
=
self
.
get_tokenizer
()
tokenization_kwargs
=
self
.
_get_tokenization_kw
(
tokenization_kwargs
)
tokenization_kwargs
=
self
.
_get_tokenization_kw
(
tokenization_kwargs
)
encoder_config
=
self
.
model_config
.
encoder_config
encoder_config
=
self
.
model_config
.
encoder_config
...
@@ -205,50 +198,39 @@ class InputPreprocessor:
...
@@ -205,50 +198,39 @@ class InputPreprocessor:
if
encoder_config
and
encoder_config
.
get
(
"do_lower_case"
,
False
):
if
encoder_config
and
encoder_config
.
get
(
"do_lower_case"
,
False
):
prompt
=
prompt
.
lower
()
prompt
=
prompt
.
lower
()
return
tokenizer
.
encode
(
prompt
=
prompt
,
return
tokenizer
.
encode
(
prompt
,
**
tokenization_kwargs
)
lora_request
=
lora_request
,
**
tokenization_kwargs
)
async
def
_tokenize_prompt_async
(
async
def
_tokenize_prompt_async
(
self
,
self
,
prompt
:
str
,
prompt
:
str
,
lora_request
:
Optional
[
LoRARequest
],
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
)
->
list
[
int
]:
)
->
list
[
int
]:
"""
"""
Async version of
Async version of
[`_tokenize_prompt`][vllm.inputs.preprocess.InputPreprocessor._tokenize_prompt].
[`_tokenize_prompt`][vllm.inputs.preprocess.InputPreprocessor._tokenize_prompt].
"""
"""
tokenizer
=
self
.
get_tokenizer
_group
()
tokenizer
=
self
.
get_tokenizer
()
tokenization_kwargs
=
self
.
_get_tokenization_kw
(
tokenization_kwargs
)
tokenization_kwargs
=
self
.
_get_tokenization_kw
(
tokenization_kwargs
)
return
await
tokenizer
.
encode_async
(
prompt
=
prompt
,
return
tokenizer
.
encode
(
prompt
,
**
tokenization_kwargs
)
lora_request
=
lora_request
,
**
tokenization_kwargs
)
def
_get_mm_tokenizer
(
def
_get_mm_tokenizer
(
self
)
->
AnyTokenizer
:
self
,
lora_request
:
Optional
[
LoRARequest
],
)
->
AnyTokenizer
:
# PrithviGeoSpatialMAE needs to be initialized without a tokenizer
# PrithviGeoSpatialMAE needs to be initialized without a tokenizer
# while using also multi-modal input
# while using also multi-modal input
if
not
self
.
tokenizer
:
if
not
self
.
tokenizer
:
return
cast
(
AnyTokenizer
,
object
())
# Dummy
return
cast
(
AnyTokenizer
,
object
())
# Dummy
tokenizer
_group
=
self
.
get_tokenizer
_group
()
tokenizer
=
self
.
get_tokenizer
()
return
tokenizer
_group
.
get_lora_tokenizer
(
lora_request
)
return
tokenizer
async
def
_get_mm_tokenizer_async
(
async
def
_get_mm_tokenizer_async
(
self
)
->
AnyTokenizer
:
self
,
lora_request
:
Optional
[
LoRARequest
],
)
->
AnyTokenizer
:
# PrithviGeoSpatialMAE needs to be initialized without a tokenizer
# PrithviGeoSpatialMAE needs to be initialized without a tokenizer
# while using also multi-modal input
# while using also multi-modal input
if
not
self
.
tokenizer
:
if
not
self
.
tokenizer
:
return
cast
(
AnyTokenizer
,
object
())
# Dummy
return
cast
(
AnyTokenizer
,
object
())
# Dummy
tokenizer
_group
=
self
.
get_tokenizer
_group
()
tokenizer
=
self
.
get_tokenizer
()
return
await
tokenizer
_group
.
get_lora_tokenizer_async
(
lora_request
)
return
tokenizer
def
_process_multimodal
(
def
_process_multimodal
(
self
,
self
,
...
@@ -256,7 +238,6 @@ class InputPreprocessor:
...
@@ -256,7 +238,6 @@ class InputPreprocessor:
mm_data
:
MultiModalDataDict
,
mm_data
:
MultiModalDataDict
,
mm_processor_kwargs
:
Optional
[
Mapping
[
str
,
object
]],
mm_processor_kwargs
:
Optional
[
Mapping
[
str
,
object
]],
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
*
,
*
,
mm_uuids
:
Optional
[
MultiModalUUIDDict
]
=
None
,
mm_uuids
:
Optional
[
MultiModalUUIDDict
]
=
None
,
)
->
MultiModalInputs
:
)
->
MultiModalInputs
:
...
@@ -264,7 +245,7 @@ class InputPreprocessor:
...
@@ -264,7 +245,7 @@ class InputPreprocessor:
Apply the model's multi-modal processor to a multi-modal prompt,
Apply the model's multi-modal processor to a multi-modal prompt,
returning the corresponding token IDs and metadata.
returning the corresponding token IDs and metadata.
"""
"""
tokenizer
=
self
.
_get_mm_tokenizer
(
lora_request
)
tokenizer
=
self
.
_get_mm_tokenizer
()
mm_processor
=
self
.
mm_registry
.
create_processor
(
mm_processor
=
self
.
mm_registry
.
create_processor
(
self
.
model_config
,
self
.
model_config
,
...
@@ -299,7 +280,6 @@ class InputPreprocessor:
...
@@ -299,7 +280,6 @@ class InputPreprocessor:
mm_data
:
MultiModalDataDict
,
mm_data
:
MultiModalDataDict
,
mm_processor_kwargs
:
Optional
[
Mapping
[
str
,
object
]],
mm_processor_kwargs
:
Optional
[
Mapping
[
str
,
object
]],
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
*
,
*
,
mm_uuids
:
Optional
[
MultiModalUUIDDict
]
=
None
,
mm_uuids
:
Optional
[
MultiModalUUIDDict
]
=
None
,
)
->
MultiModalInputs
:
)
->
MultiModalInputs
:
...
@@ -307,7 +287,7 @@ class InputPreprocessor:
...
@@ -307,7 +287,7 @@ class InputPreprocessor:
Async version of
Async version of
[`_process_multimodal`][vllm.inputs.preprocess.InputPreprocessor._process_multimodal].
[`_process_multimodal`][vllm.inputs.preprocess.InputPreprocessor._process_multimodal].
"""
"""
tokenizer
=
await
self
.
_get_mm_tokenizer_async
(
lora_request
)
tokenizer
=
await
self
.
_get_mm_tokenizer_async
()
mm_processor
=
self
.
mm_registry
.
create_processor
(
mm_processor
=
self
.
mm_registry
.
create_processor
(
self
.
model_config
,
self
.
model_config
,
...
@@ -386,7 +366,6 @@ class InputPreprocessor:
...
@@ -386,7 +366,6 @@ class InputPreprocessor:
self
,
self
,
parsed_content
:
TokensPrompt
,
parsed_content
:
TokensPrompt
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
*
,
*
,
mm_uuids
:
Optional
[
MultiModalUUIDDict
]
=
None
,
mm_uuids
:
Optional
[
MultiModalUUIDDict
]
=
None
,
)
->
Union
[
TokenInputs
,
MultiModalInputs
]:
)
->
Union
[
TokenInputs
,
MultiModalInputs
]:
...
@@ -400,7 +379,6 @@ class InputPreprocessor:
...
@@ -400,7 +379,6 @@ class InputPreprocessor:
multi_modal_data
,
multi_modal_data
,
parsed_content
.
get
(
"mm_processor_kwargs"
),
parsed_content
.
get
(
"mm_processor_kwargs"
),
tokenization_kwargs
=
tokenization_kwargs
,
tokenization_kwargs
=
tokenization_kwargs
,
lora_request
=
lora_request
,
mm_uuids
=
mm_uuids
,
mm_uuids
=
mm_uuids
,
)
)
else
:
else
:
...
@@ -415,7 +393,6 @@ class InputPreprocessor:
...
@@ -415,7 +393,6 @@ class InputPreprocessor:
self
,
self
,
parsed_content
:
TokensPrompt
,
parsed_content
:
TokensPrompt
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
*
,
*
,
mm_uuids
:
Optional
[
MultiModalUUIDDict
]
=
None
,
mm_uuids
:
Optional
[
MultiModalUUIDDict
]
=
None
,
)
->
Union
[
TokenInputs
,
MultiModalInputs
]:
)
->
Union
[
TokenInputs
,
MultiModalInputs
]:
...
@@ -429,7 +406,6 @@ class InputPreprocessor:
...
@@ -429,7 +406,6 @@ class InputPreprocessor:
multi_modal_data
,
multi_modal_data
,
parsed_content
.
get
(
"mm_processor_kwargs"
),
parsed_content
.
get
(
"mm_processor_kwargs"
),
tokenization_kwargs
=
tokenization_kwargs
,
tokenization_kwargs
=
tokenization_kwargs
,
lora_request
=
lora_request
,
mm_uuids
=
mm_uuids
,
mm_uuids
=
mm_uuids
,
)
)
else
:
else
:
...
@@ -444,7 +420,6 @@ class InputPreprocessor:
...
@@ -444,7 +420,6 @@ class InputPreprocessor:
self
,
self
,
parsed_content
:
TextPrompt
,
parsed_content
:
TextPrompt
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
*
,
*
,
mm_uuids
:
Optional
[
MultiModalUUIDDict
]
=
None
,
mm_uuids
:
Optional
[
MultiModalUUIDDict
]
=
None
,
)
->
Union
[
TokenInputs
,
MultiModalInputs
]:
)
->
Union
[
TokenInputs
,
MultiModalInputs
]:
...
@@ -457,13 +432,11 @@ class InputPreprocessor:
...
@@ -457,13 +432,11 @@ class InputPreprocessor:
multi_modal_data
,
multi_modal_data
,
parsed_content
.
get
(
"mm_processor_kwargs"
),
parsed_content
.
get
(
"mm_processor_kwargs"
),
tokenization_kwargs
=
tokenization_kwargs
,
tokenization_kwargs
=
tokenization_kwargs
,
lora_request
=
lora_request
,
mm_uuids
=
mm_uuids
,
mm_uuids
=
mm_uuids
,
)
)
else
:
else
:
prompt_token_ids
=
self
.
_tokenize_prompt
(
prompt_token_ids
=
self
.
_tokenize_prompt
(
prompt_text
,
prompt_text
,
lora_request
=
lora_request
,
tokenization_kwargs
=
tokenization_kwargs
,
tokenization_kwargs
=
tokenization_kwargs
,
)
)
inputs
=
token_inputs
(
inputs
=
token_inputs
(
...
@@ -480,7 +453,6 @@ class InputPreprocessor:
...
@@ -480,7 +453,6 @@ class InputPreprocessor:
self
,
self
,
parsed_content
:
TextPrompt
,
parsed_content
:
TextPrompt
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
*
,
*
,
mm_uuids
:
Optional
[
MultiModalUUIDDict
]
=
None
,
mm_uuids
:
Optional
[
MultiModalUUIDDict
]
=
None
,
)
->
Union
[
TokenInputs
,
MultiModalInputs
]:
)
->
Union
[
TokenInputs
,
MultiModalInputs
]:
...
@@ -493,13 +465,11 @@ class InputPreprocessor:
...
@@ -493,13 +465,11 @@ class InputPreprocessor:
multi_modal_data
,
multi_modal_data
,
parsed_content
.
get
(
"mm_processor_kwargs"
),
parsed_content
.
get
(
"mm_processor_kwargs"
),
tokenization_kwargs
=
tokenization_kwargs
,
tokenization_kwargs
=
tokenization_kwargs
,
lora_request
=
lora_request
,
mm_uuids
=
mm_uuids
,
mm_uuids
=
mm_uuids
,
)
)
else
:
else
:
prompt_token_ids
=
await
self
.
_tokenize_prompt_async
(
prompt_token_ids
=
await
self
.
_tokenize_prompt_async
(
prompt_text
,
prompt_text
,
lora_request
=
lora_request
,
tokenization_kwargs
=
tokenization_kwargs
,
tokenization_kwargs
=
tokenization_kwargs
,
)
)
inputs
=
token_inputs
(
inputs
=
token_inputs
(
...
@@ -516,7 +486,6 @@ class InputPreprocessor:
...
@@ -516,7 +486,6 @@ class InputPreprocessor:
self
,
self
,
prompt
:
SingletonPrompt
,
prompt
:
SingletonPrompt
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
*
,
*
,
mm_uuids
:
Optional
[
MultiModalUUIDDict
]
=
None
,
mm_uuids
:
Optional
[
MultiModalUUIDDict
]
=
None
,
)
->
SingletonInputs
:
)
->
SingletonInputs
:
...
@@ -526,7 +495,6 @@ class InputPreprocessor:
...
@@ -526,7 +495,6 @@ class InputPreprocessor:
Arguments:
Arguments:
* prompt: single encoder or decoder input prompt
* prompt: single encoder or decoder input prompt
* lora_request: this is only valid for decoder prompts
Returns:
Returns:
...
@@ -539,21 +507,18 @@ class InputPreprocessor:
...
@@ -539,21 +507,18 @@ class InputPreprocessor:
if
parsed
[
"type"
]
==
"tokens"
:
if
parsed
[
"type"
]
==
"tokens"
:
return
self
.
_process_tokens
(
return
self
.
_process_tokens
(
parsed
[
"content"
],
parsed
[
"content"
],
lora_request
=
lora_request
,
mm_uuids
=
mm_uuids
,
mm_uuids
=
mm_uuids
,
)
)
if
parsed
[
"type"
]
==
"text"
:
if
parsed
[
"type"
]
==
"text"
:
return
self
.
_process_text
(
return
self
.
_process_text
(
parsed
[
"content"
],
parsed
[
"content"
],
tokenization_kwargs
=
tokenization_kwargs
,
tokenization_kwargs
=
tokenization_kwargs
,
lora_request
=
lora_request
,
mm_uuids
=
mm_uuids
,
mm_uuids
=
mm_uuids
,
)
)
if
parsed
[
"type"
]
==
"str"
:
if
parsed
[
"type"
]
==
"str"
:
return
self
.
_process_text
(
return
self
.
_process_text
(
TextPrompt
(
prompt
=
parsed
[
"content"
]),
TextPrompt
(
prompt
=
parsed
[
"content"
]),
tokenization_kwargs
=
tokenization_kwargs
,
tokenization_kwargs
=
tokenization_kwargs
,
lora_request
=
lora_request
,
mm_uuids
=
mm_uuids
,
mm_uuids
=
mm_uuids
,
)
)
...
@@ -563,7 +528,6 @@ class InputPreprocessor:
...
@@ -563,7 +528,6 @@ class InputPreprocessor:
self
,
self
,
prompt
:
SingletonPrompt
,
prompt
:
SingletonPrompt
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
*
,
*
,
mm_uuids
:
Optional
[
MultiModalUUIDDict
]
=
None
,
mm_uuids
:
Optional
[
MultiModalUUIDDict
]
=
None
,
)
->
SingletonInputs
:
)
->
SingletonInputs
:
...
@@ -578,21 +542,18 @@ class InputPreprocessor:
...
@@ -578,21 +542,18 @@ class InputPreprocessor:
if
parsed
[
"type"
]
==
"tokens"
:
if
parsed
[
"type"
]
==
"tokens"
:
return
await
self
.
_process_tokens_async
(
return
await
self
.
_process_tokens_async
(
parsed
[
"content"
],
parsed
[
"content"
],
lora_request
=
lora_request
,
mm_uuids
=
mm_uuids
,
mm_uuids
=
mm_uuids
,
)
)
if
parsed
[
"type"
]
==
"text"
:
if
parsed
[
"type"
]
==
"text"
:
return
await
self
.
_process_text_async
(
return
await
self
.
_process_text_async
(
parsed
[
"content"
],
parsed
[
"content"
],
tokenization_kwargs
=
tokenization_kwargs
,
tokenization_kwargs
=
tokenization_kwargs
,
lora_request
=
lora_request
,
mm_uuids
=
mm_uuids
,
mm_uuids
=
mm_uuids
,
)
)
if
parsed
[
"type"
]
==
"str"
:
if
parsed
[
"type"
]
==
"str"
:
return
await
self
.
_process_text_async
(
return
await
self
.
_process_text_async
(
TextPrompt
(
prompt
=
parsed
[
"content"
]),
TextPrompt
(
prompt
=
parsed
[
"content"
]),
tokenization_kwargs
=
tokenization_kwargs
,
tokenization_kwargs
=
tokenization_kwargs
,
lora_request
=
lora_request
,
mm_uuids
=
mm_uuids
,
mm_uuids
=
mm_uuids
,
)
)
...
@@ -844,7 +805,6 @@ class InputPreprocessor:
...
@@ -844,7 +805,6 @@ class InputPreprocessor:
self
,
self
,
prompt
:
SingletonPrompt
,
prompt
:
SingletonPrompt
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
*
,
*
,
mm_uuids
:
Optional
[
MultiModalUUIDDict
]
=
None
,
mm_uuids
:
Optional
[
MultiModalUUIDDict
]
=
None
,
)
->
DecoderOnlyInputs
:
)
->
DecoderOnlyInputs
:
...
@@ -856,7 +816,6 @@ class InputPreprocessor:
...
@@ -856,7 +816,6 @@ class InputPreprocessor:
Arguments:
Arguments:
* prompt: input prompt
* prompt: input prompt
* lora_request
Returns:
Returns:
...
@@ -866,7 +825,6 @@ class InputPreprocessor:
...
@@ -866,7 +825,6 @@ class InputPreprocessor:
prompt_comps
=
self
.
_prompt_to_llm_inputs
(
prompt_comps
=
self
.
_prompt_to_llm_inputs
(
prompt
,
prompt
,
tokenization_kwargs
=
tokenization_kwargs
,
tokenization_kwargs
=
tokenization_kwargs
,
lora_request
=
lora_request
,
mm_uuids
=
mm_uuids
,
mm_uuids
=
mm_uuids
,
)
)
...
@@ -876,7 +834,6 @@ class InputPreprocessor:
...
@@ -876,7 +834,6 @@ class InputPreprocessor:
self
,
self
,
prompt
:
SingletonPrompt
,
prompt
:
SingletonPrompt
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
*
,
*
,
mm_uuids
:
Optional
[
MultiModalUUIDDict
]
=
None
,
mm_uuids
:
Optional
[
MultiModalUUIDDict
]
=
None
,
)
->
DecoderOnlyInputs
:
)
->
DecoderOnlyInputs
:
...
@@ -887,7 +844,6 @@ class InputPreprocessor:
...
@@ -887,7 +844,6 @@ class InputPreprocessor:
prompt_comps
=
await
self
.
_prompt_to_llm_inputs_async
(
prompt_comps
=
await
self
.
_prompt_to_llm_inputs_async
(
prompt
,
prompt
,
tokenization_kwargs
=
tokenization_kwargs
,
tokenization_kwargs
=
tokenization_kwargs
,
lora_request
=
lora_request
,
mm_uuids
=
mm_uuids
,
mm_uuids
=
mm_uuids
,
)
)
...
@@ -897,7 +853,6 @@ class InputPreprocessor:
...
@@ -897,7 +853,6 @@ class InputPreprocessor:
self
,
self
,
prompt
:
PromptType
,
prompt
:
PromptType
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
*
,
*
,
mm_uuids
:
Optional
[
MultiModalUUIDDict
]
=
None
,
mm_uuids
:
Optional
[
MultiModalUUIDDict
]
=
None
,
)
->
ProcessorInputs
:
)
->
ProcessorInputs
:
...
@@ -919,7 +874,6 @@ class InputPreprocessor:
...
@@ -919,7 +874,6 @@ class InputPreprocessor:
return
self
.
_process_decoder_only_prompt
(
return
self
.
_process_decoder_only_prompt
(
prompt
,
prompt
,
tokenization_kwargs
=
tokenization_kwargs
,
tokenization_kwargs
=
tokenization_kwargs
,
lora_request
=
lora_request
,
mm_uuids
=
mm_uuids
,
mm_uuids
=
mm_uuids
,
)
)
...
@@ -927,7 +881,6 @@ class InputPreprocessor:
...
@@ -927,7 +881,6 @@ class InputPreprocessor:
self
,
self
,
prompt
:
PromptType
,
prompt
:
PromptType
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
*
,
*
,
mm_uuids
:
Optional
[
MultiModalUUIDDict
]
=
None
,
mm_uuids
:
Optional
[
MultiModalUUIDDict
]
=
None
,
)
->
ProcessorInputs
:
)
->
ProcessorInputs
:
...
@@ -952,7 +905,6 @@ class InputPreprocessor:
...
@@ -952,7 +905,6 @@ class InputPreprocessor:
return
await
self
.
_process_decoder_only_prompt_async
(
return
await
self
.
_process_decoder_only_prompt_async
(
prompt
,
prompt
,
tokenization_kwargs
=
tokenization_kwargs
,
tokenization_kwargs
=
tokenization_kwargs
,
lora_request
=
lora_request
,
mm_uuids
=
mm_uuids
,
mm_uuids
=
mm_uuids
,
)
)
...
...
vllm/transformers_utils/detokenizer.py
View file @
6c47f6bf
...
@@ -10,18 +10,13 @@ from vllm.sequence import (VLLM_INVALID_TOKEN_ID, SamplingParams, Sequence,
...
@@ -10,18 +10,13 @@ from vllm.sequence import (VLLM_INVALID_TOKEN_ID, SamplingParams, Sequence,
from
.detokenizer_utils
import
(
convert_prompt_ids_to_tokens
,
from
.detokenizer_utils
import
(
convert_prompt_ids_to_tokens
,
detokenize_incrementally
)
detokenize_incrementally
)
from
.tokenizer
import
AnyTokenizer
from
.tokenizer
import
AnyTokenizer
from
.tokenizer_group
import
TokenizerGroup
class
Detokenizer
:
class
Detokenizer
:
"""Provides methods to decode the output of a model into text."""
"""Provides methods to decode the output of a model into text."""
def
__init__
(
self
,
tokenizer_group
:
TokenizerGroup
):
def
__init__
(
self
,
tokenizer
:
AnyTokenizer
):
self
.
tokenizer_group
=
tokenizer_group
self
.
tokenizer
=
tokenizer
def
get_tokenizer_for_seq
(
self
,
sequence
:
Sequence
)
->
AnyTokenizer
:
"""Returns the HF tokenizer to use for a given sequence."""
return
self
.
tokenizer_group
.
get_lora_tokenizer
(
sequence
.
lora_request
)
def
decode_prompt_logprobs_inplace
(
self
,
seq_group
:
SequenceGroup
,
def
decode_prompt_logprobs_inplace
(
self
,
seq_group
:
SequenceGroup
,
prompt_logprobs
:
list
[
Optional
[
dict
[
prompt_logprobs
:
list
[
Optional
[
dict
[
...
@@ -32,9 +27,9 @@ class Detokenizer:
...
@@ -32,9 +27,9 @@ class Detokenizer:
Args:
Args:
seq_group: The sequence group to decode.
seq_group: The sequence group to decode.
prompt_logprobs: The logprobs to decode.
prompt_logprobs: The logprobs to decode.
position_offset: Offset of the first index of the logprobs
position_offset: Offset of the first index of the logprobs
relative to the start of the sequence (for chunked prefill).
relative to the start of the sequence (for chunked prefill).
Returns:
Returns:
The prompt logprobs with the decoded tokens.
The prompt logprobs with the decoded tokens.
"""
"""
...
@@ -46,7 +41,6 @@ class Detokenizer:
...
@@ -46,7 +41,6 @@ class Detokenizer:
# Only prompt, without the generated token.
# Only prompt, without the generated token.
all_token_ids
=
seq
.
get_token_ids
()
all_token_ids
=
seq
.
get_token_ids
()
prompt_token_ids
=
all_token_ids
[:
-
1
]
prompt_token_ids
=
all_token_ids
[:
-
1
]
tokenizer
=
self
.
get_tokenizer_for_seq
(
seq
)
prefix_offset
=
0
prefix_offset
=
0
read_offset
=
0
read_offset
=
0
next_iter_prefix_offset
=
0
next_iter_prefix_offset
=
0
...
@@ -70,7 +64,7 @@ class Detokenizer:
...
@@ -70,7 +64,7 @@ class Detokenizer:
prompt_token_ids
[:
token_position
]
+
[
token_id
])
prompt_token_ids
[:
token_position
]
+
[
token_id
])
(
new_tokens
,
new_text
,
new_prefix_offset
,
(
new_tokens
,
new_text
,
new_prefix_offset
,
new_read_offset
)
=
detokenize_incrementally
(
new_read_offset
)
=
detokenize_incrementally
(
tokenizer
=
tokenizer
,
tokenizer
=
self
.
tokenizer
,
all_input_ids
=
prompt_token_ids_with_token
,
all_input_ids
=
prompt_token_ids_with_token
,
prev_tokens
=
prev_tokens
,
prev_tokens
=
prev_tokens
,
prefix_offset
=
prefix_offset
,
prefix_offset
=
prefix_offset
,
...
@@ -111,7 +105,6 @@ class Detokenizer:
...
@@ -111,7 +105,6 @@ class Detokenizer:
"""
"""
all_input_ids
=
seq
.
get_token_ids
()
all_input_ids
=
seq
.
get_token_ids
()
token_id_generated_this_iteration
=
all_input_ids
[
-
1
]
token_id_generated_this_iteration
=
all_input_ids
[
-
1
]
tokenizer
=
self
.
get_tokenizer_for_seq
(
seq
)
# Convert prompt token IDs to tokens if necessary.
# Convert prompt token IDs to tokens if necessary.
# Do it here so that we don't have to repeat this
# Do it here so that we don't have to repeat this
...
@@ -119,14 +112,14 @@ class Detokenizer:
...
@@ -119,14 +112,14 @@ class Detokenizer:
if
seq
.
tokens
is
None
:
if
seq
.
tokens
is
None
:
(
seq
.
tokens
,
seq
.
prefix_offset
,
(
seq
.
tokens
,
seq
.
prefix_offset
,
seq
.
read_offset
)
=
convert_prompt_ids_to_tokens
(
seq
.
read_offset
)
=
convert_prompt_ids_to_tokens
(
tokenizer
=
tokenizer
,
tokenizer
=
self
.
tokenizer
,
prompt_ids
=
all_input_ids
[:
-
1
],
prompt_ids
=
all_input_ids
[:
-
1
],
skip_special_tokens
=
prms
.
skip_special_tokens
,
skip_special_tokens
=
prms
.
skip_special_tokens
,
)
)
(
new_tokens
,
new_decoded_token_text
,
prefix_offset
,
(
new_tokens
,
new_decoded_token_text
,
prefix_offset
,
read_offset
)
=
detokenize_incrementally
(
read_offset
)
=
detokenize_incrementally
(
tokenizer
=
tokenizer
,
tokenizer
=
self
.
tokenizer
,
all_input_ids
=
all_input_ids
,
all_input_ids
=
all_input_ids
,
prev_tokens
=
seq
.
tokens
,
prev_tokens
=
seq
.
tokens
,
prefix_offset
=
seq
.
prefix_offset
,
prefix_offset
=
seq
.
prefix_offset
,
...
@@ -150,7 +143,7 @@ class Detokenizer:
...
@@ -150,7 +143,7 @@ class Detokenizer:
and
token_id
!=
VLLM_INVALID_TOKEN_ID
):
and
token_id
!=
VLLM_INVALID_TOKEN_ID
):
all_input_ids_with_logprob
=
previous_tokens
+
[
token_id
]
all_input_ids_with_logprob
=
previous_tokens
+
[
token_id
]
(
_
,
new_text
,
_
,
_
)
=
detokenize_incrementally
(
(
_
,
new_text
,
_
,
_
)
=
detokenize_incrementally
(
tokenizer
=
tokenizer
,
tokenizer
=
self
.
tokenizer
,
all_input_ids
=
all_input_ids_with_logprob
,
all_input_ids
=
all_input_ids_with_logprob
,
prev_tokens
=
seq
.
tokens
,
prev_tokens
=
seq
.
tokens
,
prefix_offset
=
seq
.
prefix_offset
,
prefix_offset
=
seq
.
prefix_offset
,
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment