Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
355f6634
Unverified
Commit
355f6634
authored
Mar 28, 2025
by
Cyrus Leung
Committed by
GitHub
Mar 27, 2025
Browse files
[V1] Remove legacy input registry (#15673)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
8693e47e
Changes
17
Hide whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
132 additions
and
153 deletions
+132
-153
tests/models/multimodal/processing/test_h2ovl.py
tests/models/multimodal/processing/test_h2ovl.py
+1
-6
tests/models/multimodal/processing/test_idefics3.py
tests/models/multimodal/processing/test_idefics3.py
+1
-6
tests/models/multimodal/processing/test_internvl.py
tests/models/multimodal/processing/test_internvl.py
+1
-6
tests/models/multimodal/processing/test_llava_next.py
tests/models/multimodal/processing/test_llava_next.py
+3
-13
tests/models/multimodal/processing/test_llava_onevision.py
tests/models/multimodal/processing/test_llava_onevision.py
+3
-13
tests/models/multimodal/processing/test_phi3v.py
tests/models/multimodal/processing/test_phi3v.py
+1
-6
tests/models/multimodal/processing/test_qwen2_vl.py
tests/models/multimodal/processing/test_qwen2_vl.py
+2
-6
tests/multimodal/test_processing.py
tests/multimodal/test_processing.py
+4
-14
vllm/inputs/preprocess.py
vllm/inputs/preprocess.py
+6
-6
vllm/inputs/registry.py
vllm/inputs/registry.py
+17
-8
vllm/multimodal/profiling.py
vllm/multimodal/profiling.py
+24
-31
vllm/multimodal/registry.py
vllm/multimodal/registry.py
+53
-10
vllm/v1/engine/async_llm.py
vllm/v1/engine/async_llm.py
+4
-3
vllm/v1/engine/llm_engine.py
vllm/v1/engine/llm_engine.py
+1
-3
vllm/v1/engine/processor.py
vllm/v1/engine/processor.py
+10
-12
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+1
-8
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+0
-2
No files found.
tests/models/multimodal/processing/test_h2ovl.py
View file @
355f6634
...
...
@@ -10,7 +10,6 @@ from transformers import PretrainedConfig
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.image
import
rescale_image_size
from
vllm.multimodal.processing
import
BaseMultiModalProcessor
from
vllm.transformers_utils.tokenizer
import
cached_tokenizer_from_config
from
....conftest
import
_ImageAssets
from
...utils
import
build_model_context
...
...
@@ -156,11 +155,7 @@ def test_processor_override(
mm_processor_kwargs
=
mm_processor_kwargs
if
kwargs_on_init
else
None
,
limit_mm_per_prompt
=
{
"image"
:
len
(
size_factors
)},
)
tokenizer
=
cached_tokenizer_from_config
(
ctx
.
model_config
)
processor
=
MULTIMODAL_REGISTRY
.
create_processor
(
ctx
.
model_config
,
tokenizer
=
tokenizer
,
)
processor
=
MULTIMODAL_REGISTRY
.
create_processor
(
ctx
.
model_config
)
hf_processor_mm_kwargs
=
{}
if
kwargs_on_init
else
mm_processor_kwargs
min_num
=
min_dynamic_patch
if
dynamic_image_size
else
1
...
...
tests/models/multimodal/processing/test_idefics3.py
View file @
355f6634
...
...
@@ -4,7 +4,6 @@ import pytest
from
transformers
import
Idefics3Config
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.transformers_utils.tokenizer
import
cached_tokenizer_from_config
from
....conftest
import
_ImageAssets
from
...utils
import
build_model_context
...
...
@@ -38,11 +37,7 @@ def test_processor_override(
mm_processor_kwargs
=
mm_processor_kwargs
if
kwargs_on_init
else
None
,
limit_mm_per_prompt
=
{
"image"
:
num_imgs
},
)
tokenizer
=
cached_tokenizer_from_config
(
ctx
.
model_config
)
processor
=
MULTIMODAL_REGISTRY
.
create_processor
(
ctx
.
model_config
,
tokenizer
=
tokenizer
,
)
processor
=
MULTIMODAL_REGISTRY
.
create_processor
(
ctx
.
model_config
)
hf_processor_mm_kwargs
=
{}
if
kwargs_on_init
else
mm_processor_kwargs
# Build the image str / prompt based on the number of images we pass
...
...
tests/models/multimodal/processing/test_internvl.py
View file @
355f6634
...
...
@@ -10,7 +10,6 @@ from transformers import PretrainedConfig
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.image
import
rescale_image_size
from
vllm.multimodal.processing
import
BaseMultiModalProcessor
from
vllm.transformers_utils.tokenizer
import
cached_tokenizer_from_config
from
....conftest
import
_ImageAssets
from
...utils
import
build_model_context
...
...
@@ -113,11 +112,7 @@ def test_processor_override(
mm_processor_kwargs
=
mm_processor_kwargs
if
kwargs_on_init
else
None
,
limit_mm_per_prompt
=
{
"image"
:
len
(
size_factors
)},
)
tokenizer
=
cached_tokenizer_from_config
(
ctx
.
model_config
)
processor
=
MULTIMODAL_REGISTRY
.
create_processor
(
ctx
.
model_config
,
tokenizer
=
tokenizer
,
)
processor
=
MULTIMODAL_REGISTRY
.
create_processor
(
ctx
.
model_config
)
hf_processor_mm_kwargs
=
{}
if
kwargs_on_init
else
mm_processor_kwargs
min_num
=
min_dynamic_patch
if
dynamic_image_size
else
1
...
...
tests/models/multimodal/processing/test_llava_next.py
View file @
355f6634
...
...
@@ -10,7 +10,6 @@ from pqdm.threads import pqdm
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.parse
import
ImageSize
from
vllm.multimodal.processing
import
BaseMultiModalProcessor
from
vllm.transformers_utils.tokenizer
import
cached_tokenizer_from_config
from
...utils
import
build_model_context
...
...
@@ -40,10 +39,7 @@ def test_processor_max_tokens(model_id):
mm_processor_kwargs
=
None
,
limit_mm_per_prompt
=
{
"image"
:
1
},
)
processor
=
MULTIMODAL_REGISTRY
.
create_processor
(
ctx
.
model_config
,
tokenizer
=
cached_tokenizer_from_config
(
ctx
.
model_config
),
)
processor
=
MULTIMODAL_REGISTRY
.
create_processor
(
ctx
.
model_config
)
info
=
processor
.
info
seen_aspect_ratios
=
set
[
float
]()
...
...
@@ -139,10 +135,7 @@ def test_processor_prompt_replacements_regression(model_id, num_imgs):
mm_processor_kwargs
=
None
,
limit_mm_per_prompt
=
{
"image"
:
num_imgs
},
)
processor
=
MULTIMODAL_REGISTRY
.
create_processor
(
ctx
.
model_config
,
tokenizer
=
cached_tokenizer_from_config
(
ctx
.
model_config
),
)
processor
=
MULTIMODAL_REGISTRY
.
create_processor
(
ctx
.
model_config
)
image_ratios
=
[(
171
,
152
),
(
184
,
161
),
(
198
,
176
),
(
333
,
296
),
(
369
,
328
),
(
488
,
183
),
(
2560
,
1669
)]
...
...
@@ -168,10 +161,7 @@ def test_processor_prompt_replacements_all(model_id, num_imgs):
mm_processor_kwargs
=
None
,
limit_mm_per_prompt
=
{
"image"
:
num_imgs
},
)
processor
=
MULTIMODAL_REGISTRY
.
create_processor
(
ctx
.
model_config
,
tokenizer
=
cached_tokenizer_from_config
(
ctx
.
model_config
),
)
processor
=
MULTIMODAL_REGISTRY
.
create_processor
(
ctx
.
model_config
)
seen_aspect_ratios
=
set
[
float
]()
image_sizes
=
list
[
ImageSize
]()
...
...
tests/models/multimodal/processing/test_llava_onevision.py
View file @
355f6634
...
...
@@ -10,7 +10,6 @@ from pqdm.threads import pqdm
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.parse
import
ImageSize
from
vllm.multimodal.processing
import
BaseMultiModalProcessor
from
vllm.transformers_utils.tokenizer
import
cached_tokenizer_from_config
from
...utils
import
build_model_context
...
...
@@ -41,10 +40,7 @@ def test_processor_max_tokens(model_id):
mm_processor_kwargs
=
None
,
limit_mm_per_prompt
=
{
"image"
:
1
},
)
processor
=
MULTIMODAL_REGISTRY
.
create_processor
(
ctx
.
model_config
,
tokenizer
=
cached_tokenizer_from_config
(
ctx
.
model_config
),
)
processor
=
MULTIMODAL_REGISTRY
.
create_processor
(
ctx
.
model_config
)
info
=
processor
.
info
seen_aspect_ratios
=
set
[
float
]()
...
...
@@ -139,10 +135,7 @@ def test_processor_prompt_replacements_regression(model_id, num_imgs):
mm_processor_kwargs
=
None
,
limit_mm_per_prompt
=
{
"image"
:
num_imgs
},
)
processor
=
MULTIMODAL_REGISTRY
.
create_processor
(
ctx
.
model_config
,
tokenizer
=
cached_tokenizer_from_config
(
ctx
.
model_config
),
)
processor
=
MULTIMODAL_REGISTRY
.
create_processor
(
ctx
.
model_config
)
image_ratios
=
[(
171
,
152
),
(
184
,
161
),
(
198
,
176
),
(
333
,
296
),
(
369
,
328
),
(
488
,
183
),
(
2560
,
1669
)]
...
...
@@ -169,10 +162,7 @@ def test_processor_prompt_replacements_all(model_id, num_imgs):
mm_processor_kwargs
=
None
,
limit_mm_per_prompt
=
{
"image"
:
num_imgs
},
)
processor
=
MULTIMODAL_REGISTRY
.
create_processor
(
ctx
.
model_config
,
tokenizer
=
cached_tokenizer_from_config
(
ctx
.
model_config
),
)
processor
=
MULTIMODAL_REGISTRY
.
create_processor
(
ctx
.
model_config
)
seen_aspect_ratios
=
set
[
float
]()
image_sizes
=
list
[
ImageSize
]()
...
...
tests/models/multimodal/processing/test_phi3v.py
View file @
355f6634
...
...
@@ -3,7 +3,6 @@
import
pytest
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.transformers_utils.tokenizer
import
cached_tokenizer_from_config
from
....conftest
import
_ImageAssets
from
...utils
import
build_model_context
...
...
@@ -39,11 +38,7 @@ def test_processor_override(
mm_processor_kwargs
=
mm_processor_kwargs
if
kwargs_on_init
else
None
,
limit_mm_per_prompt
=
{
"image"
:
num_imgs
},
)
tokenizer
=
cached_tokenizer_from_config
(
ctx
.
model_config
)
processor
=
MULTIMODAL_REGISTRY
.
create_processor
(
ctx
.
model_config
,
tokenizer
=
tokenizer
,
)
processor
=
MULTIMODAL_REGISTRY
.
create_processor
(
ctx
.
model_config
)
hf_processor_mm_kwargs
=
{}
if
kwargs_on_init
else
mm_processor_kwargs
# Build the image str / prompt based on the number of images we pass
...
...
tests/models/multimodal/processing/test_qwen2_vl.py
View file @
355f6634
...
...
@@ -3,7 +3,6 @@
import
pytest
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.transformers_utils.tokenizer
import
cached_tokenizer_from_config
from
....conftest
import
_ImageAssets
from
...utils
import
build_model_context
...
...
@@ -34,11 +33,8 @@ def test_processor_override(
mm_processor_kwargs
=
mm_processor_kwargs
if
kwargs_on_init
else
None
,
limit_mm_per_prompt
=
{
"image"
:
num_imgs
},
)
tokenizer
=
cached_tokenizer_from_config
(
ctx
.
model_config
)
processor
=
MULTIMODAL_REGISTRY
.
create_processor
(
ctx
.
model_config
,
tokenizer
=
tokenizer
,
)
processor
=
MULTIMODAL_REGISTRY
.
create_processor
(
ctx
.
model_config
)
tokenizer
=
processor
.
info
.
get_tokenizer
()
hf_processor_mm_kwargs
=
{}
if
kwargs_on_init
else
mm_processor_kwargs
# Build the image str / prompt based on the number of images we pass
...
...
tests/multimodal/test_processing.py
View file @
355f6634
...
...
@@ -28,8 +28,7 @@ from vllm.multimodal.processing import (PlaceholderFeaturesInfo,
replace_token_matches
)
# yapf: enable
from
vllm.multimodal.profiling
import
MultiModalProfiler
from
vllm.transformers_utils.tokenizer
import
(
AnyTokenizer
,
cached_tokenizer_from_config
)
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.utils
import
full_groupby
from
.utils
import
random_image
...
...
@@ -955,10 +954,7 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
limit_mm_per_prompt
=
limit_mm_per_prompt
,
)
processor
=
MULTIMODAL_REGISTRY
.
create_processor
(
model_config
,
tokenizer
=
cached_tokenizer_from_config
(
model_config
),
)
processor
=
MULTIMODAL_REGISTRY
.
create_processor
(
model_config
)
profiler
=
MultiModalProfiler
(
processor
)
mock_supported_mm_limits
=
MagicMock
(
return_value
=
{
"image"
:
num_supported
})
...
...
@@ -994,10 +990,7 @@ def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid):
limit_mm_per_prompt
=
limit_mm_per_prompt
,
)
processor
=
MULTIMODAL_REGISTRY
.
create_processor
(
model_config
,
tokenizer
=
cached_tokenizer_from_config
(
model_config
),
)
processor
=
MULTIMODAL_REGISTRY
.
create_processor
(
model_config
)
rng
=
np
.
random
.
RandomState
(
0
)
image
=
random_image
(
rng
,
min_wh
=
128
,
max_wh
=
256
)
...
...
@@ -1066,10 +1059,7 @@ def test_hf_processor_kwargs(model_id, call_kwargs, expected_kwargs):
revision
=
None
,
)
processor
=
MULTIMODAL_REGISTRY
.
create_processor
(
model_config
,
tokenizer
=
cached_tokenizer_from_config
(
model_config
),
)
processor
=
MULTIMODAL_REGISTRY
.
create_processor
(
model_config
)
orig_get_hf_processor
=
processor
.
info
.
get_hf_processor
def
get_hf_processor
(
self
,
**
kwargs
):
...
...
vllm/inputs/preprocess.py
View file @
355f6634
...
...
@@ -261,13 +261,13 @@ class InputPreprocessor:
# initialized without a tokenizer while using also multi-modal
# input.
if
not
self
.
tokenizer
:
tokenizer
=
None
tokenizer
=
object
()
# Dummy
else
:
tokenizer_group
=
self
.
get_tokenizer_group
()
tokenizer
=
tokenizer_group
.
get_lora_tokenizer
(
lora_request
)
mm_processor
=
self
.
mm_registry
.
create_processor
(
self
.
model_config
,
tokenizer
)
mm_processor
=
self
.
mm_registry
.
create_processor
(
self
.
model_config
,
tokenizer
=
tokenizer
)
if
mm_processor_kwargs
is
None
:
mm_processor_kwargs
=
{}
...
...
@@ -288,14 +288,14 @@ class InputPreprocessor:
# initialized without a tokenizer while using also multi-modal
# input.
if
not
self
.
tokenizer
:
tokenizer
=
None
tokenizer
=
object
()
# Dummy
else
:
tokenizer_group
=
self
.
get_tokenizer_group
()
tokenizer
=
await
tokenizer_group
.
get_lora_tokenizer_async
(
lora_request
)
mm_processor
=
self
.
mm_registry
.
create_processor
(
self
.
model_config
,
tokenizer
)
mm_processor
=
self
.
mm_registry
.
create_processor
(
self
.
model_config
,
tokenizer
=
tokenizer
)
if
mm_processor_kwargs
is
None
:
mm_processor_kwargs
=
{}
...
...
vllm/inputs/registry.py
View file @
355f6634
...
...
@@ -13,8 +13,7 @@ from typing_extensions import TypeVar, assert_never
from
vllm.logger
import
init_logger
from
vllm.transformers_utils.processor
import
cached_processor_from_config
from
vllm.transformers_utils.tokenizer
import
(
AnyTokenizer
,
cached_tokenizer_from_config
)
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.utils
import
(
ClassRegistry
,
get_allowed_kwarg_only_overrides
,
resolve_mm_processor_kwargs
)
...
...
@@ -329,17 +328,27 @@ class InputRegistry:
from
vllm.model_executor.model_loader
import
get_model_architecture
from
vllm.multimodal
import
MultiModalKwargs
from
vllm.multimodal.profiling
import
MultiModalProfiler
from
vllm.sequence
import
SequenceData
if
mm_registry
.
has_processor
(
model_config
):
tokenizer
=
cached_tokenizer_from_config
(
model_config
)
processor
=
mm_registry
.
create_processor
(
model_config
,
tokenizer
,
disable_cache
=
True
)
profiler
=
MultiModalProfiler
(
processor
)
dummy_data_factory
=
(
profiler
.
get_encoder_dummy_data
if
is_encoder_data
else
profiler
.
get_decoder_dummy_data
)
dummy_data
=
dummy_data_factory
(
seq_len
)
dummy_data_v1
=
(
profiler
.
get_encoder_dummy_data
(
seq_len
)
if
is_encoder_data
else
profiler
.
get_decoder_dummy_data
(
seq_len
))
_seq_data
=
SequenceData
.
from_seqs
(
dummy_data_v1
.
prompt_token_ids
)
# type: ignore[attr-defined]
dummy_data
=
DummyData
(
seq_data
=
_seq_data
,
multi_modal_data
=
getattr
(
dummy_data_v1
,
"multi_modal_data"
,
None
),
multi_modal_placeholders
=
getattr
(
dummy_data_v1
,
"multi_modal_placeholders"
,
None
),
)
else
:
model_cls
,
_
=
get_model_architecture
(
model_config
)
if
is_encoder_data
:
...
...
vllm/multimodal/profiling.py
View file @
355f6634
...
...
@@ -3,18 +3,18 @@
from
abc
import
ABC
,
abstractmethod
from
collections.abc
import
Mapping
from
dataclasses
import
dataclass
,
field
from
typing
import
Generic
,
TypeVar
,
cast
from
typing
import
Generic
,
NamedTuple
,
TypeVar
,
cast
import
numpy
as
np
import
numpy.typing
as
npt
from
PIL
import
Image
import
vllm.envs
as
envs
from
vllm.inputs
import
DummyData
from
vllm.logger
import
init_logger
from
.inputs
import
(
MultiModalDataDict
,
MultiModalEncDecInputs
,
MultiModalInputs
)
MultiModalInputs
,
MultiModalKwargs
,
MultiModalPlaceholderDict
)
from
.processing
import
BaseMultiModalProcessor
,
BaseProcessingInfo
logger
=
init_logger
(
__name__
)
...
...
@@ -31,6 +31,20 @@ class ProcessorInputs:
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
]
=
field
(
default_factory
=
dict
)
class
DummyEncoderData
(
NamedTuple
):
"""Dummy data used for profiling."""
prompt_token_ids
:
list
[
int
]
class
DummyDecoderData
(
NamedTuple
):
"""Dummy data used for profiling."""
prompt_token_ids
:
list
[
int
]
multi_modal_data
:
MultiModalKwargs
multi_modal_placeholders
:
MultiModalPlaceholderDict
_I
=
TypeVar
(
"_I"
,
bound
=
BaseProcessingInfo
)
...
...
@@ -179,13 +193,7 @@ class MultiModalProfiler(Generic[_I]):
"tokens."
)
return
mm_inputs
,
total_placeholders_by_modality
def
get_encoder_dummy_data
(
self
,
seq_len
:
int
,
)
->
DummyData
:
# Avoid circular import
from
vllm.sequence
import
SequenceData
def
get_encoder_dummy_data
(
self
,
seq_len
:
int
)
->
DummyEncoderData
:
mm_inputs
,
_
=
self
.
get_and_validate_mm_inputs
(
seq_len
)
mm_inputs
=
cast
(
MultiModalEncDecInputs
,
mm_inputs
)
...
...
@@ -197,19 +205,9 @@ class MultiModalProfiler(Generic[_I]):
num_tokens_to_pad
=
max
(
total_len
,
seq_len
)
-
total_len
encoder_prompt_token_ids
.
extend
([
0
]
*
num_tokens_to_pad
)
return
DummyData
(
seq_data
=
SequenceData
.
from_seqs
(
encoder_prompt_token_ids
),
multi_modal_data
=
None
,
multi_modal_placeholders
=
None
,
)
def
get_decoder_dummy_data
(
self
,
seq_len
:
int
,
)
->
DummyData
:
# Avoid circular import
from
vllm.sequence
import
SequenceData
return
DummyEncoderData
(
encoder_prompt_token_ids
)
def
get_decoder_dummy_data
(
self
,
seq_len
:
int
)
->
DummyDecoderData
:
(
mm_inputs
,
total_placeholders_by_modality
)
=
self
.
get_and_validate_mm_inputs
(
seq_len
)
...
...
@@ -231,16 +229,11 @@ class MultiModalProfiler(Generic[_I]):
"and/or reduce `mm_counts`."
,
seq_len
,
total_len
,
total_placeholders_by_modality
)
return
DummyData
(
seq_data
=
SequenceData
.
from_prompt_token_counts
((
0
,
seq_len
)),
multi_modal_data
=
None
,
multi_modal_placeholders
=
None
,
)
prompt_token_ids
.
extend
([
0
]
*
(
seq_len
-
len
(
prompt_token_ids
)))
if
total_len
<
seq_len
:
prompt_token_ids
.
extend
([
0
]
*
(
seq_len
-
total_len
))
return
DummyData
(
seq_data
=
SequenceData
.
from_seqs
(
prompt_token_ids
)
,
return
DummyD
ecoderD
ata
(
prompt_token_ids
=
prompt_token_ids
,
multi_modal_data
=
mm_inputs
[
"mm_kwargs"
],
multi_modal_placeholders
=
mm_inputs
[
"mm_placeholders"
],
)
vllm/multimodal/registry.py
View file @
355f6634
...
...
@@ -21,7 +21,8 @@ from .image import ImagePlugin
from
.inputs
import
MultiModalDataDict
,
MultiModalKwargs
,
NestedTensors
from
.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
ProcessingCache
)
from
.profiling
import
BaseDummyInputsBuilder
,
MultiModalProfiler
from
.profiling
import
(
BaseDummyInputsBuilder
,
DummyDecoderData
,
DummyEncoderData
,
MultiModalProfiler
)
from
.video
import
VideoPlugin
if
TYPE_CHECKING
:
...
...
@@ -256,10 +257,7 @@ class MultiModalRegistry:
on underlying model configuration.
"""
if
self
.
has_processor
(
model_config
):
tokenizer
=
cached_tokenizer_from_config
(
model_config
)
processor
=
self
.
create_processor
(
model_config
,
tokenizer
,
disable_cache
=
True
)
processor
=
self
.
create_processor
(
model_config
,
disable_cache
=
True
)
seq_len
=
model_config
.
max_model_len
mm_limits
=
self
.
get_mm_limits_per_prompt
(
model_config
)
return
processor
.
info
.
get_mm_max_tokens_per_item
(
...
...
@@ -373,10 +371,7 @@ class MultiModalRegistry:
This should be called after :meth:`init_mm_limits_per_prompt`.
"""
if
self
.
has_processor
(
model_config
):
tokenizer
=
cached_tokenizer_from_config
(
model_config
)
processor
=
self
.
create_processor
(
model_config
,
tokenizer
,
disable_cache
=
True
)
processor
=
self
.
create_processor
(
model_config
,
disable_cache
=
True
)
profiler
=
MultiModalProfiler
(
processor
)
return
profiler
.
get_mm_limits
()
...
...
@@ -436,8 +431,8 @@ class MultiModalRegistry:
def
create_processor
(
self
,
model_config
:
"ModelConfig"
,
tokenizer
:
AnyTokenizer
,
*
,
tokenizer
:
Optional
[
AnyTokenizer
]
=
None
,
disable_cache
:
Optional
[
bool
]
=
None
,
)
->
BaseMultiModalProcessor
[
BaseProcessingInfo
]:
"""
...
...
@@ -446,6 +441,8 @@ class MultiModalRegistry:
See also:
:ref:`mm-processing`
"""
if
tokenizer
is
None
:
tokenizer
=
cached_tokenizer_from_config
(
model_config
)
if
disable_cache
is
None
:
disable_cache
=
model_config
.
disable_mm_preprocessor_cache
...
...
@@ -456,3 +453,49 @@ class MultiModalRegistry:
cache
=
None
if
disable_cache
else
self
.
_processing_cache
return
factories
.
build_processor
(
ctx
,
cache
=
cache
)
def
get_decoder_dummy_data
(
self
,
model_config
:
"ModelConfig"
,
seq_len
:
int
,
)
->
DummyDecoderData
:
"""
Create dummy data for profiling the memory usage of a model.
The model is identified by ``model_config``.
"""
processor
=
self
.
create_processor
(
model_config
,
disable_cache
=
True
)
profiler
=
MultiModalProfiler
(
processor
)
dummy_data
=
profiler
.
get_decoder_dummy_data
(
seq_len
)
# Having more tokens is over-conservative but otherwise fine
token_ids
=
dummy_data
.
prompt_token_ids
if
len
(
token_ids
)
<
seq_len
:
raise
AssertionError
(
f
"Expected at least
{
seq_len
}
dummy tokens for profiling, "
f
"but found
{
len
(
token_ids
)
}
tokens instead."
)
return
dummy_data
def
get_encoder_dummy_data
(
self
,
model_config
:
"ModelConfig"
,
seq_len
:
int
,
)
->
DummyEncoderData
:
"""
Create dummy data for profiling the memory usage of a model.
The model is identified by ``model_config``.
"""
processor
=
self
.
create_processor
(
model_config
,
disable_cache
=
True
)
profiler
=
MultiModalProfiler
(
processor
)
dummy_data
=
profiler
.
get_encoder_dummy_data
(
seq_len
)
# Having more tokens is over-conservative but otherwise fine
token_ids
=
dummy_data
.
prompt_token_ids
if
len
(
token_ids
)
<
seq_len
:
logger
.
warning_once
(
f
"Expected at least
{
seq_len
}
dummy encoder tokens for "
f
"profiling, but found
{
len
(
token_ids
)
}
tokens instead."
)
return
dummy_data
vllm/v1/engine/async_llm.py
View file @
355f6634
...
...
@@ -14,10 +14,11 @@ from vllm.config import ModelConfig, VllmConfig
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.protocol
import
EngineClient
from
vllm.envs
import
VLLM_V1_OUTPUT_PROC_CHUNK_SIZE
from
vllm.inputs
import
INPUT_REGISTRY
,
InputRegistry
,
PromptType
from
vllm.inputs
import
PromptType
from
vllm.inputs.preprocess
import
InputPreprocessor
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalRegistry
from
vllm.outputs
import
RequestOutput
from
vllm.pooling_params
import
PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
...
...
@@ -48,7 +49,7 @@ class AsyncLLM(EngineClient):
executor_class
:
type
[
Executor
],
log_stats
:
bool
,
usage_context
:
UsageContext
=
UsageContext
.
ENGINE_CONTEXT
,
input
_registry
:
Input
Registry
=
INPUT
_REGISTRY
,
mm
_registry
:
MultiModal
Registry
=
MULTIMODAL
_REGISTRY
,
use_cached_outputs
:
bool
=
False
,
log_requests
:
bool
=
True
,
start_engine_loop
:
bool
=
True
,
...
...
@@ -90,7 +91,7 @@ class AsyncLLM(EngineClient):
self
.
processor
=
Processor
(
vllm_config
=
vllm_config
,
tokenizer
=
self
.
tokenizer
,
input
_registry
=
input
_registry
,
mm
_registry
=
mm
_registry
,
)
# OutputProcessor (converts EngineCoreOutputs --> RequestOutput).
...
...
vllm/v1/engine/llm_engine.py
View file @
355f6634
...
...
@@ -11,7 +11,7 @@ from vllm.config import ParallelConfig, VllmConfig
from
vllm.distributed
import
stateless_destroy_torch_distributed_process_group
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.metrics_types
import
StatLoggerBase
from
vllm.inputs
import
INPUT_REGISTRY
,
InputRegistry
,
PromptType
from
vllm.inputs
import
PromptType
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalRegistry
...
...
@@ -44,7 +44,6 @@ class LLMEngine:
log_stats
:
bool
,
usage_context
:
UsageContext
=
UsageContext
.
ENGINE_CONTEXT
,
stat_loggers
:
Optional
[
dict
[
str
,
StatLoggerBase
]]
=
None
,
input_registry
:
InputRegistry
=
INPUT_REGISTRY
,
mm_registry
:
MultiModalRegistry
=
MULTIMODAL_REGISTRY
,
use_cached_outputs
:
bool
=
False
,
multiprocess_mode
:
bool
=
False
,
...
...
@@ -80,7 +79,6 @@ class LLMEngine:
# Processor (convert Inputs --> EngineCoreRequests)
self
.
processor
=
Processor
(
vllm_config
=
vllm_config
,
tokenizer
=
self
.
tokenizer
,
input_registry
=
input_registry
,
mm_registry
=
mm_registry
)
# OutputProcessor (convert EngineCoreOutputs --> RequestOutput).
...
...
vllm/v1/engine/processor.py
View file @
355f6634
...
...
@@ -5,8 +5,7 @@ from collections.abc import Mapping
from
typing
import
Optional
,
Union
from
vllm.config
import
VllmConfig
from
vllm.inputs
import
(
INPUT_REGISTRY
,
InputRegistry
,
ProcessorInputs
,
PromptType
,
SingletonInputsAdapter
)
from
vllm.inputs
import
ProcessorInputs
,
PromptType
from
vllm.inputs.parse
import
split_enc_dec_inputs
from
vllm.inputs.preprocess
import
InputPreprocessor
from
vllm.lora.request
import
LoRARequest
...
...
@@ -31,7 +30,6 @@ class Processor:
self
,
vllm_config
:
VllmConfig
,
tokenizer
:
BaseTokenizerGroup
,
input_registry
:
InputRegistry
=
INPUT_REGISTRY
,
mm_registry
:
MultiModalRegistry
=
MULTIMODAL_REGISTRY
,
):
...
...
@@ -210,7 +208,6 @@ class Processor:
self
.
_validate_model_inputs
(
processed_inputs
,
lora_request
)
encoder_inputs
,
decoder_inputs
=
split_enc_dec_inputs
(
processed_inputs
)
decoder_inputs
=
SingletonInputsAdapter
(
decoder_inputs
)
# TODO: Impl encoder-decoder
if
encoder_inputs
is
not
None
:
...
...
@@ -221,8 +218,9 @@ class Processor:
sampling_params
=
params
.
clone
()
# If unset max tokens, then generate up to the max_model_len.
if
sampling_params
.
max_tokens
is
None
:
sampling_params
.
max_tokens
=
(
self
.
model_config
.
max_model_len
-
len
(
decoder_inputs
.
prompt_token_ids
))
sampling_params
.
max_tokens
=
(
self
.
model_config
.
max_model_len
-
len
(
decoder_inputs
[
"prompt_token_ids"
]))
sampling_params
.
update_from_generation_config
(
self
.
generation_config_fields
,
eos_token_id
)
sampling_params
.
update_from_tokenizer
(
...
...
@@ -232,8 +230,8 @@ class Processor:
sorted_mm_inputs
:
Optional
[
list
[
MultiModalKwargs
]]
=
None
sorted_mm_positions
:
Optional
[
list
[
PlaceholderRange
]]
=
None
sorted_mm_hashes
:
Optional
[
list
[
str
]]
=
None
if
(
decoder_
mm_
inputs
:
=
decoder_inputs
.
multi
_
modal
_data
)
:
assert
isinstance
(
decoder_mm_inputs
,
MultiModalK
wargs
)
if
decoder_inputs
[
"type"
]
==
"
multimodal
"
:
decoder_mm_inputs
=
decoder_inputs
[
"mm_k
wargs
"
]
# The output of merged multi-modal processor (`decoder_mm_inputs`)
# contains the kwargs for all items from all modalities.
...
...
@@ -254,8 +252,8 @@ class Processor:
sorted_mm_positions
,
sorted_mm_hashes
,
)
=
merge_and_sort_multimodal_metadata
(
decoder_inputs
.
multi_modal
_placeholders
,
decoder_inputs
.
multi_modal
_hashes
if
self
.
use_hash
else
None
,
decoder_inputs
[
"mm
_placeholders
"
]
,
decoder_inputs
[
"mm
_hashes
"
]
if
self
.
use_hash
else
None
,
)
# NOTE: Sort multimodal inputs/kwargs ONLY IF there are multiple
...
...
@@ -281,8 +279,8 @@ class Processor:
return
EngineCoreRequest
(
request_id
=
request_id
,
prompt
=
decoder_inputs
.
prompt
,
prompt_token_ids
=
decoder_inputs
.
prompt_token_ids
,
prompt
=
decoder_inputs
.
get
(
"
prompt
"
)
,
prompt_token_ids
=
decoder_inputs
[
"
prompt_token_ids
"
]
,
mm_inputs
=
sorted_mm_inputs
,
mm_hashes
=
sorted_mm_hashes
,
mm_placeholders
=
sorted_mm_positions
,
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
355f6634
...
...
@@ -15,7 +15,6 @@ from vllm.attention.layer import Attention
from
vllm.config
import
CompilationLevel
,
VllmConfig
from
vllm.distributed.parallel_state
import
get_pp_group
,
graph_capture
from
vllm.forward_context
import
set_forward_context
from
vllm.inputs
import
INPUT_REGISTRY
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.rotary_embedding
import
MRotaryEmbedding
...
...
@@ -130,7 +129,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
cascade_attn_enabled
=
not
self
.
model_config
.
disable_cascade_attn
# Multi-modal data support
self
.
input_registry
=
INPUT_REGISTRY
self
.
mm_registry
=
MULTIMODAL_REGISTRY
self
.
uses_mrope
=
model_config
.
uses_mrope
...
...
@@ -1473,16 +1471,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
encoder_budget
,
max_num_mm_items
,
dummy_data_modality
)
# Create dummy batch of multimodal inputs.
dummy_request_data
=
self
.
input
_registry
.
dummy_data_for_profiling
(
dummy_request_data
=
self
.
mm
_registry
.
get_decoder_dummy_data
(
model_config
=
self
.
model_config
,
seq_len
=
self
.
max_num_tokens
,
mm_registry
=
self
.
mm_registry
,
)
dummy_mm_data
=
dummy_request_data
.
multi_modal_data
if
not
isinstance
(
dummy_mm_data
,
MultiModalKwargs
):
# TODO: Delete this check once input mapper is fully removed.
raise
RuntimeError
(
"Legacy input mapper is not supported in V1"
)
# Dummy data definition may contain multiple multimodal items
# (e.g, multiple images) for a single request, therefore here we
...
...
vllm/v1/worker/tpu_model_runner.py
View file @
355f6634
...
...
@@ -17,7 +17,6 @@ from vllm.attention.backends.abstract import AttentionType
from
vllm.attention.layer
import
Attention
from
vllm.config
import
VllmConfig
from
vllm.forward_context
import
set_forward_context
from
vllm.inputs
import
INPUT_REGISTRY
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader
import
get_model
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalKwargs
...
...
@@ -102,7 +101,6 @@ class TPUModelRunner:
self
.
hidden_size
=
model_config
.
get_hidden_size
()
# Multi-modal data support
self
.
input_registry
=
INPUT_REGISTRY
self
.
mm_registry
=
MULTIMODAL_REGISTRY
self
.
uses_mrope
=
model_config
.
uses_mrope
# TODO: Support M-RoPE (e.g, Qwen2-VL)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment