Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9b8c8ba1
Unverified
Commit
9b8c8ba1
authored
Sep 23, 2024
by
Alex Brooks
Committed by
GitHub
Sep 23, 2024
Browse files
[Core][Frontend] Support Passing Multimodal Processor Kwargs (#8657)
Signed-off-by:
Alex-Brooks
<
Alex.Brooks@ibm.com
>
parent
d23679eb
Changes
16
Hide whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
589 additions
and
116 deletions
+589
-116
tests/engine/test_arg_utils.py
tests/engine/test_arg_utils.py
+21
-0
tests/models/decoder_only/vision_language/test_qwen.py
tests/models/decoder_only/vision_language/test_qwen.py
+1
-28
tests/models/utils.py
tests/models/utils.py
+35
-0
tests/multimodal/test_processor_kwargs.py
tests/multimodal/test_processor_kwargs.py
+339
-0
vllm/config.py
vllm/config.py
+5
-1
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+8
-0
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+2
-1
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+2
-0
vllm/inputs/registry.py
vllm/inputs/registry.py
+27
-11
vllm/multimodal/base.py
vllm/multimodal/base.py
+16
-3
vllm/multimodal/image.py
vllm/multimodal/image.py
+8
-2
vllm/multimodal/registry.py
vllm/multimodal/registry.py
+9
-0
vllm/multimodal/video.py
vllm/multimodal/video.py
+7
-2
vllm/transformers_utils/image_processor.py
vllm/transformers_utils/image_processor.py
+0
-64
vllm/transformers_utils/processor.py
vllm/transformers_utils/processor.py
+61
-4
vllm/utils.py
vllm/utils.py
+48
-0
No files found.
tests/engine/test_arg_utils.py
View file @
9b8c8ba1
...
...
@@ -40,3 +40,24 @@ def test_limit_mm_per_prompt_parser(arg, expected):
def
test_bad_nullable_kvs
(
arg
):
with
pytest
.
raises
(
ArgumentTypeError
):
nullable_kvs
(
arg
)
@pytest.mark.parametrize(
    ("arg", "expected"),
    [
        (None, None),
        ("{}", {}),
        ('{"num_crops": 4}', {"num_crops": 4}),
        ('{"foo": {"bar": "baz"}}', {"foo": {"bar": "baz"}}),
    ],
)
def test_mm_processor_kwargs_prompt_parser(arg, expected):
    """Check that ``--mm-processor-kwargs`` is parsed from JSON by the CLI.

    When ``arg`` is None the flag is omitted entirely, so the parsed value
    is expected to be the argument's default (None).
    """
    parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
    # Only pass the flag when there is a value to parse.
    cli_args = [] if arg is None else ["--mm-processor-kwargs", arg]
    args = parser.parse_args(cli_args)
    assert args.mm_processor_kwargs == expected
tests/models/decoder_only/vision_language/test_qwen.py
View file @
9b8c8ba1
...
...
@@ -5,14 +5,13 @@ import pytest
import
torch
from
PIL.Image
import
Image
from
vllm.config
import
ModelConfig
from
vllm.inputs
import
InputContext
,
LLMInputs
from
vllm.multimodal.base
import
MultiModalInputs
from
vllm.multimodal.utils
import
cached_get_tokenizer
,
rescale_image_size
from
....conftest
import
(
IMAGE_ASSETS
,
HfRunner
,
ImageAsset
,
PromptImageInput
,
VllmRunner
,
_ImageAssets
)
from
...utils
import
check_logprobs_close
from
...utils
import
build_model_context
,
check_logprobs_close
text_only_models
=
[
"Qwen/Qwen-7B-Chat"
# Has no visual component
...
...
@@ -42,32 +41,6 @@ VIS_ENC_DIM = 4096
IMG_SIZE
=
448
def
build_model_context
(
model_name
:
str
,
tokenizer_name
:
Optional
[
str
]
=
None
,
trust_remote_code
:
bool
=
False
):
"""Creates an InputContext for a given model.
Args:
model_name: Name of the model being considered.
tokenizer_name: Name of the tokenizer being considered.
trust_remote_code: Whether or not to allow loading remote code.
Returns:
InputContext for the model being considered.
"""
if
tokenizer_name
is
None
:
tokenizer_name
=
model_name
model_config
=
ModelConfig
(
model_name
,
tokenizer_name
,
tokenizer_mode
=
"auto"
,
trust_remote_code
=
trust_remote_code
,
dtype
=
"float32"
,
seed
=
0
,
)
return
InputContext
(
model_config
)
@
pytest
.
fixture
()
def
input_mapper_for_qwen
():
# Lazy import to avoid initializing CUDA during test collection
...
...
tests/models/utils.py
View file @
9b8c8ba1
import
warnings
from
typing
import
Dict
,
List
,
Optional
,
Sequence
,
Tuple
,
Union
from
vllm.config
import
ModelConfig
from
vllm.inputs
import
InputContext
from
vllm.sequence
import
Logprob
,
PromptLogprobs
,
SampleLogprobs
TokensText
=
Tuple
[
List
[
int
],
str
]
...
...
@@ -240,3 +242,36 @@ def check_logprobs_close(
warnings
.
simplefilter
(
"always"
)
warnings
.
warn
(
fail_msg
,
stacklevel
=
2
)
def build_model_context(model_name: str,
                        tokenizer_name: Optional[str] = None,
                        trust_remote_code: bool = False,
                        mm_processor_kwargs: Optional[Dict] = None,
                        limit_mm_per_prompt: Optional[Dict] = None):
    """Build an InputContext wrapping a ModelConfig for the given model.

    Args:
        model_name: Name of the model being considered.
        tokenizer_name: Name of the tokenizer being considered; falls back
            to ``model_name`` when not given.
        trust_remote_code: Whether or not to allow loading remote code.
        mm_processor_kwargs: Optional processor kwargs forwarded to the
            ModelConfig, to be leveraged in the input processor, mapper,
            dummy data creation, etc.
        limit_mm_per_prompt: Multimodal limits.

    Returns:
        InputContext for the model being considered.
    """
    resolved_tokenizer = (model_name
                          if tokenizer_name is None else tokenizer_name)
    config_kwargs = dict(
        tokenizer_mode="auto",
        trust_remote_code=trust_remote_code,
        dtype="float32",
        seed=0,
        mm_processor_kwargs=mm_processor_kwargs,
        limit_mm_per_prompt=limit_mm_per_prompt,
    )
    return InputContext(
        ModelConfig(model_name, resolved_tokenizer, **config_kwargs))
tests/multimodal/test_processor_kwargs.py
0 → 100644
View file @
9b8c8ba1
from
array
import
array
from
typing
import
Mapping
from
unittest.mock
import
patch
import
pytest
import
torch
from
vllm.inputs
import
InputContext
,
LLMInputs
from
vllm.inputs.registry
import
InputRegistry
from
vllm.multimodal
import
MultiModalRegistry
from
vllm.sequence
import
VLLM_TOKEN_ID_ARRAY_TYPE
,
SequenceData
from
..models.utils
import
build_model_context
# Model used by fast tests where the concrete model does not matter
DUMMY_MODEL_ID = "facebook/opt-125m"
# Model used by tests that need a multimodal model
MULTIMODAL_MODEL_ID = "microsoft/Phi-3.5-vision-instruct"

# For mm_processor_kwargs, overrides are tested by defining a mock for each
# place the kwargs are used and checking that passing an override value
# through the processor kwargs yields the intended result (e.g., for things
# like sequence length).
DEFAULT_NUM_CROPS = 4
NUM_CROPS_OVERRIDE = 16
# Mocks for all of the places that we use the mm_processor_kwargs
# to override values in different callables
@pytest.fixture
def use_processor_mock():
    """Patch the registry's model input processor lookup for one test.

    While the fixture is active,
    ``InputRegistry._get_model_input_processor`` returns a stub processor
    that accepts a keyword-only ``num_crops`` argument.
    """

    def custom_processor(ctx: InputContext,
                         llm_inputs: LLMInputs,
                         *,
                         num_crops=DEFAULT_NUM_CROPS):
        # The llm inputs / return type are deliberately not validated here;
        # the stub simply echoes back the kwarg that overrides may clobber.
        return num_crops

    patch_target = (
        "vllm.inputs.registry.InputRegistry._get_model_input_processor")
    with patch(patch_target, return_value=custom_processor):
        yield
@pytest.fixture
def use_dummy_data_mock():
    """Patch the registry's default dummy data factory for one test.

    The replacement factory accepts a keyword-only ``num_crops`` argument
    and returns sequence data holding exactly ``num_crops`` zero tokens.
    """

    def custom_dummy_data_factory(self,
                                  ctx: InputContext,
                                  seq_len: int,
                                  mm_counts: Mapping[str, int],
                                  *,
                                  num_crops=DEFAULT_NUM_CROPS):
        # seq_len is ignored on purpose: the number of tokens produced
        # depends only on num_crops.
        token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * num_crops)
        return SequenceData(token_ids), None

    patch_target = (
        "vllm.inputs.registry.InputRegistry._default_dummy_data_factory")
    with patch(patch_target, custom_dummy_data_factory):
        yield
def mm_model_cls():
    """Return the Phi3V model class used as the multimodal model in tests."""
    # Imported lazily to avoid a CUDA reinitialization error
    from vllm.model_executor.models import phi3v

    return phi3v.Phi3VForCausalLM
# Callables whose signatures match the max token calculation and the input
# mapper respectively, each extended with a keyword-only `num_crops` override.
def get_num_crops(ctx, *, num_crops=DEFAULT_NUM_CROPS):
    """Echo back ``num_crops`` (stand-in for a max token calculation)."""
    return num_crops


def custom_mapper(ctx, data, *, num_crops=DEFAULT_NUM_CROPS):
    """Return a zero tensor whose second dimension is ``num_crops + 1``."""
    # NOTE(review): the mapper tests index the mapped inputs with
    # "pixel_values", while this mapper emits "num_pixels" - confirm which
    # key is intended.
    crops_shape = (1, num_crops + 1, 3, 336, 336)
    return {"num_pixels": torch.zeros(size=crops_shape)}
### Test for default processor logic & mm_processor_kwargs wrapping
def test_default_processor_is_a_noop():
    """Without an override, the processor returns its inputs unchanged."""
    registry = InputRegistry()
    model_config = build_model_context(DUMMY_MODEL_ID).model_config
    processor = registry.create_input_processor(model_config)

    proc_inputs = LLMInputs(prompt_token_ids=[], prompt="")
    # Identity (not just equality): the very same object must come back.
    assert processor(inputs=proc_inputs) is proc_inputs
@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE])
def test_processor_default_kwargs(use_processor_mock, num_crops):
    """Ensure input processors can use processor kwargs."""
    registry = InputRegistry()
    # With a num_crops value, pass it as an override and expect the mock
    # processor to return it; otherwise expect the mock's default value.
    if num_crops is None:
        mm_processor_kwargs = None
        expected_num_crops = DEFAULT_NUM_CROPS
    else:
        mm_processor_kwargs = {"num_crops": num_crops}
        expected_num_crops = num_crops

    ctx = build_model_context(DUMMY_MODEL_ID,
                              mm_processor_kwargs=mm_processor_kwargs)
    processor = registry.create_input_processor(ctx.model_config)

    llm_inputs = LLMInputs(prompt_token_ids=[], prompt="")
    assert processor(llm_inputs) == expected_num_crops
@pytest.mark.parametrize(
    "mm_processor_kwargs",
    [
        # Not part of the signature
        {"does_not_exist": 100},
        # Part of the signature, not keyword only
        {"ctx": "something bad"},
    ],
)
def test_processor_with_sad_kwarg_overrides(use_processor_mock,
                                            mm_processor_kwargs):
    """Ensure that input processors filter out invalid mm_processor_kwargs"""
    registry = InputRegistry()
    ctx = build_model_context(DUMMY_MODEL_ID,
                              mm_processor_kwargs=mm_processor_kwargs)
    processor = registry.create_input_processor(ctx.model_config)

    # The invalid overrides are dropped, so the mock's default comes back.
    llm_inputs = LLMInputs(prompt_token_ids=[], prompt="")
    assert processor(llm_inputs) == DEFAULT_NUM_CROPS
### Test overrides for the dummy data
@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE])
def test_dummy_data_kwarg_overrides(use_dummy_data_mock, num_crops):
    """Ensure dummy data factories can use processor kwargs."""
    if num_crops is None:
        mm_processor_kwargs = None
        expected_seq_count = DEFAULT_NUM_CROPS
    else:
        mm_processor_kwargs = {"num_crops": num_crops}
        expected_seq_count = num_crops

    input_registry = InputRegistry()
    ctx = build_model_context(DUMMY_MODEL_ID,
                              mm_processor_kwargs=mm_processor_kwargs)
    mm_registry = MultiModalRegistry()
    mm_registry.init_mm_limits_per_prompt(ctx.model_config)

    # NOTE: seq_len is thrown away here, because the patched-in default
    # dummy data factory produces a sequence whose length depends solely on
    # the value of the mm_processor_kwargs.
    seq_data, _ = input_registry.dummy_data_for_profiling(
        ctx.model_config,
        seq_len=-1,
        mm_registry=mm_registry,
    )
    assert len(seq_data.prompt_token_ids) == expected_seq_count
@pytest.mark.parametrize(
    "mm_processor_kwargs",
    [
        # Not part of the signature
        {"does_not_exist": 100},
        # Part of the signature, not keyword only
        {"ctx": "something bad"},
    ],
)
def test_dummy_data_with_sad_kwarg_overrides(use_dummy_data_mock,
                                             mm_processor_kwargs):
    """Ensure the dummy data factory filters out invalid mm_processor_kwargs"""
    input_registry = InputRegistry()
    ctx = build_model_context(DUMMY_MODEL_ID,
                              mm_processor_kwargs=mm_processor_kwargs)
    mm_registry = MultiModalRegistry()
    mm_registry.init_mm_limits_per_prompt(ctx.model_config)

    # NOTE: seq_len is thrown away here, because the patched-in default
    # dummy data factory produces a sequence whose length depends solely on
    # the value of the mm_processor_kwargs.
    seq_data, _ = input_registry.dummy_data_for_profiling(
        ctx.model_config,
        seq_len=-1,
        mm_registry=mm_registry,
    )
    # The invalid overrides are dropped, so the default length is expected.
    assert len(seq_data.prompt_token_ids) == DEFAULT_NUM_CROPS
### Test overrides for the max token count per multimodal instance
@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE])
def test_max_tokens_kwarg_overrides(num_crops):
    """Ensure max token calcs can use processor kwargs."""
    if num_crops is None:
        mm_processor_kwargs = None
        expected_seq_count = DEFAULT_NUM_CROPS
    else:
        mm_processor_kwargs = {"num_crops": num_crops}
        expected_seq_count = num_crops

    ctx = build_model_context(
        MULTIMODAL_MODEL_ID,
        trust_remote_code=True,
        mm_processor_kwargs=mm_processor_kwargs,
        limit_mm_per_prompt={"image": 1},
    )
    mm_registry = MultiModalRegistry()
    mm_registry.init_mm_limits_per_prompt(ctx.model_config)

    # Swap the image plugin's max token table for one that maps phi3v to a
    # callable accepting overrides, then check that the num_crops value from
    # the mm_processor_kwargs is echoed back.
    image_plugin = mm_registry._get_plugin("image")
    patched_max_tokens = {mm_model_cls(): get_num_crops}
    with patch.object(image_plugin, "_max_mm_tokens", patched_max_tokens):
        max_multimodal_tokens = mm_registry.get_max_multimodal_tokens(
            ctx.model_config)

    assert expected_seq_count == max_multimodal_tokens
@pytest.mark.parametrize(
    "mm_processor_kwargs",
    [
        # Not part of the signature
        {"does_not_exist": 100},
        # Part of the signature, not keyword only
        {"ctx": "something bad"},
    ],
)
def test_max_tokens_with_sad_kwarg_overrides(mm_processor_kwargs):
    """Ensure that max token calcs filters out invalid mm_processor_kwargs"""
    ctx = build_model_context(
        MULTIMODAL_MODEL_ID,
        trust_remote_code=True,
        mm_processor_kwargs=mm_processor_kwargs,
        limit_mm_per_prompt={"image": 1},
    )
    mm_registry = MultiModalRegistry()
    mm_registry.init_mm_limits_per_prompt(ctx.model_config)

    # Same setup as the valid-override test, but since these kwargs get
    # filtered out, the default value always comes back.
    image_plugin = mm_registry._get_plugin("image")
    patched_max_tokens = {mm_model_cls(): get_num_crops}
    with patch.object(image_plugin, "_max_mm_tokens", patched_max_tokens):
        max_multimodal_tokens = mm_registry.get_max_multimodal_tokens(
            ctx.model_config)

    assert max_multimodal_tokens == DEFAULT_NUM_CROPS
### Test overrides for the mapper
@pytest.mark.parametrize("num_crops", [DEFAULT_NUM_CROPS, NUM_CROPS_OVERRIDE])
def test_default_mapper_with_processer_kwargs(image_assets, num_crops):
    """Ensure that the mapper processor kwargs can fall back to HF models."""
    # NOTE - bad inputs are not validated for the default mapper: it goes
    # through the automodel interface in transformers, so the allowed
    # kwargs cannot easily be inspected.
    ctx = build_model_context(
        MULTIMODAL_MODEL_ID,
        trust_remote_code=True,
        mm_processor_kwargs={"num_crops": num_crops},
        limit_mm_per_prompt={"image": 1},
    )
    mm_registry = MultiModalRegistry()
    mm_registry.init_mm_limits_per_prompt(ctx.model_config)

    mm_inputs = {"image": image_assets[0].pil_image}
    mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs)

    # Phi3v pixel vals should have shape: [batch, num_crops+1, 3, 336, 336]
    pixel_values = mapped_inputs["pixel_values"]
    assert pixel_values.shape[1] == num_crops + 1
@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE])
def test_custom_mapper_kwarg_overrides(image_assets, num_crops):
    """Ensure custom mappers can use processor kwargs.

    The custom mapper is installed by replacing the image plugin's
    ``_input_mappers`` table, which is the mapping consulted per model
    class when inputs are mapped. Patching ``_default_input_mapper`` with a
    dict (as this test previously did) never routes through the custom
    mapper, so the override path was not actually exercised.
    """
    if num_crops is None:
        mm_processor_kwargs = None
        expected_seq_count = DEFAULT_NUM_CROPS
    else:
        mm_processor_kwargs = {"num_crops": num_crops}
        expected_seq_count = num_crops

    ctx = build_model_context(
        MULTIMODAL_MODEL_ID,
        trust_remote_code=True,
        mm_processor_kwargs=mm_processor_kwargs,
        limit_mm_per_prompt={"image": 1},
    )
    mm_registry = MultiModalRegistry()
    mm_registry.init_mm_limits_per_prompt(ctx.model_config)

    mm_inputs = {"image": image_assets[0].pil_image}

    # Register custom_mapper for phi3v on the image plugin, then check that
    # the num_crops value from the mm_processor_kwargs shapes its output.
    with patch.object(
            mm_registry._get_plugin("image"),
            "_input_mappers",
        {mm_model_cls(): custom_mapper},
    ):
        mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs)

    # custom_mapper emits its tensor under "num_pixels" with shape
    # [1, num_crops + 1, 3, 336, 336].
    assert mapped_inputs["num_pixels"].shape[1] == expected_seq_count + 1
@pytest.mark.parametrize(
    "mm_processor_kwargs",
    [
        # Not part of the signature
        {"does_not_exist": 100},
        # Part of the signature, not keyword only
        {"ctx": "something bad"},
    ],
)
def test_custom_mapper_with_sad_kwarg_overrides(image_assets,
                                                mm_processor_kwargs):
    """Ensure that custom mappers filters out invalid mm_processor_kwargs

    The custom mapper is installed by replacing the image plugin's
    ``_input_mappers`` table, which is the mapping consulted per model
    class when inputs are mapped. Patching ``_default_input_mapper`` with a
    dict (as this test previously did) never routes through the custom
    mapper, so the filtering path was not actually exercised.
    """
    ctx = build_model_context(
        MULTIMODAL_MODEL_ID,
        trust_remote_code=True,
        mm_processor_kwargs=mm_processor_kwargs,
        limit_mm_per_prompt={"image": 1},
    )
    mm_registry = MultiModalRegistry()
    mm_registry.init_mm_limits_per_prompt(ctx.model_config)

    mm_inputs = {"image": image_assets[0].pil_image}

    # Register custom_mapper for phi3v on the image plugin; the invalid
    # overrides should be filtered out before the mapper is invoked.
    with patch.object(
            mm_registry._get_plugin("image"),
            "_input_mappers",
        {mm_model_cls(): custom_mapper},
    ):
        mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs)

    # custom_mapper emits its tensor under "num_pixels"; with the overrides
    # dropped, the default crop count determines the shape.
    assert mapped_inputs["num_pixels"].shape[1] == DEFAULT_NUM_CROPS + 1
vllm/config.py
View file @
9b8c8ba1
...
...
@@ -122,6 +122,8 @@ class ModelConfig:
can not be gathered from the vllm arguments.
config_format: The config format which shall be loaded.
Defaults to 'auto' which defaults to 'hf'.
mm_processor_kwargs: Arguments to be forwarded to the model's processor
for multi-modal data, e.g., image processor.
"""
def
__init__
(
self
,
...
...
@@ -150,7 +152,8 @@ class ModelConfig:
limit_mm_per_prompt
:
Optional
[
Mapping
[
str
,
int
]]
=
None
,
use_async_output_proc
:
bool
=
True
,
override_neuron_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
config_format
:
ConfigFormat
=
ConfigFormat
.
AUTO
)
->
None
:
config_format
:
ConfigFormat
=
ConfigFormat
.
AUTO
,
mm_processor_kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
)
->
None
:
self
.
model
=
model
self
.
tokenizer
=
tokenizer
self
.
tokenizer_mode
=
tokenizer_mode
...
...
@@ -184,6 +187,7 @@ class ModelConfig:
self
.
model
,
revision
)
self
.
dtype
=
_get_and_verify_dtype
(
self
.
hf_text_config
,
dtype
)
self
.
use_async_output_proc
=
use_async_output_proc
self
.
mm_processor_kwargs
=
mm_processor_kwargs
# Set enforce_eager to False if the value is unset.
if
self
.
enforce_eager
is
None
:
...
...
vllm/engine/arg_utils.py
View file @
9b8c8ba1
...
...
@@ -175,6 +175,7 @@ class EngineArgs:
collect_detailed_traces
:
Optional
[
str
]
=
None
disable_async_output_proc
:
bool
=
False
override_neuron_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
mm_processor_kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
def
__post_init__
(
self
):
if
self
.
tokenizer
is
None
:
...
...
@@ -513,6 +514,12 @@ class EngineArgs:
'e.g.: `image=16,video=2` allows a maximum of 16 '
'images and 2 videos per prompt. Defaults to 1 for '
'each modality.'
))
parser
.
add_argument
(
'--mm-processor-kwargs'
,
default
=
None
,
type
=
json
.
loads
,
help
=
(
'Overrides for the multimodal input mapping/processing,'
'e.g., image processor. For example: {"num_crops": 4}.'
))
# LoRA related configs
parser
.
add_argument
(
'--enable-lora'
,
...
...
@@ -822,6 +829,7 @@ class EngineArgs:
use_async_output_proc
=
not
self
.
disable_async_output_proc
,
override_neuron_config
=
self
.
override_neuron_config
,
config_format
=
self
.
config_format
,
mm_processor_kwargs
=
self
.
mm_processor_kwargs
,
)
def
create_load_config
(
self
)
->
LoadConfig
:
...
...
vllm/engine/llm_engine.py
View file @
9b8c8ba1
...
...
@@ -235,7 +235,7 @@ class LLMEngine:
"decoding_config=%r, observability_config=%r, "
"seed=%d, served_model_name=%s, use_v2_block_manager=%s, "
"num_scheduler_steps=%d, enable_prefix_caching=%s, "
"use_async_output_proc=%s)"
,
"use_async_output_proc=%s
, mm_processor_kwargs=%s
)"
,
VLLM_VERSION
,
model_config
.
model
,
speculative_config
,
...
...
@@ -268,6 +268,7 @@ class LLMEngine:
scheduler_config
.
num_scheduler_steps
,
cache_config
.
enable_prefix_caching
,
model_config
.
use_async_output_proc
,
model_config
.
mm_processor_kwargs
,
)
# TODO(woosuk): Print more configs in debug mode.
from
vllm.plugins
import
load_general_plugins
...
...
vllm/entrypoints/llm.py
View file @
9b8c8ba1
...
...
@@ -134,6 +134,7 @@ class LLM:
max_seq_len_to_capture
:
int
=
8192
,
disable_custom_all_reduce
:
bool
=
False
,
disable_async_output_proc
:
bool
=
False
,
mm_processor_kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
**
kwargs
,
)
->
None
:
'''
...
...
@@ -174,6 +175,7 @@ class LLM:
max_seq_len_to_capture
=
max_seq_len_to_capture
,
disable_custom_all_reduce
=
disable_custom_all_reduce
,
disable_async_output_proc
=
disable_async_output_proc
,
mm_processor_kwargs
=
mm_processor_kwargs
,
**
kwargs
,
)
self
.
llm_engine
=
LLMEngine
.
from_engine_args
(
...
...
vllm/inputs/registry.py
View file @
9b8c8ba1
...
...
@@ -9,6 +9,7 @@ from transformers import PretrainedConfig
from
typing_extensions
import
TypeVar
from
vllm.logger
import
init_logger
from
vllm.utils
import
get_allowed_kwarg_only_overrides
from
.data
import
LLMInputs
...
...
@@ -68,12 +69,17 @@ class DummyDataFactory(Protocol):
ctx
:
InputContext
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
**
mm_processor_kwargs
:
Any
,
)
->
Tuple
[
"SequenceData"
,
Optional
[
"MultiModalDataDict"
]]:
"""
Create dummy data to be inputted into the model.
Note:
:data:`InputProcessor` is not applied to the dummy data.
The :code:`mm_processor_kwargs` are overrides provided at
initialization time to values in the config whose values
may affect the number of tokens per instance.
"""
...
...
...
@@ -152,6 +158,10 @@ class InputRegistry:
return
wrapper
def
_get_dummy_data_factory
(
self
,
model_cls
:
Type
[
nn
.
Module
]):
return
self
.
_dummy_factories_by_model_type
\
.
get
(
model_cls
,
self
.
_default_dummy_data_factory
)
def
dummy_data_for_profiling
(
self
,
model_config
:
"ModelConfig"
,
...
...
@@ -174,15 +184,15 @@ class InputRegistry:
from
vllm.model_executor.model_loader
import
get_model_architecture
model_cls
,
_
=
get_model_architecture
(
model_config
)
dummy_factory
=
self
.
_dummy_factor
ies_by_model_type
\
.
get
(
model_cls
,
self
.
_default_dummy_data_factory
)
dummy_factory
=
self
.
_
get_
dummy_
data_
factor
y
(
model_cls
)
mm_counts
=
mm_registry
.
get_mm_limits_per_prompt
(
model_config
)
mm_processor_kwargs
=
get_allowed_kwarg_only_overrides
(
dummy_factory
,
overrides
=
model_config
.
mm_processor_kwargs
)
seq_data
,
mm_data
=
dummy_factory
(
InputContext
(
model_config
),
seq_len
,
_MultiModalCounts
(
mm_counts
),
)
seq_data
,
mm_data
=
dummy_factory
(
InputContext
(
model_config
),
seq_len
,
_MultiModalCounts
(
mm_counts
),
**
mm_processor_kwargs
)
# Having more tokens is over-conservative but otherwise fine
num_tokens
=
seq_data
.
prompt_token_ids
...
...
@@ -229,6 +239,10 @@ class InputRegistry:
return
wrapper
def
_get_model_input_processor
(
self
,
model_cls
:
Type
[
nn
.
Module
]):
return
self
.
_input_processors_by_model_type
\
.
get
(
model_cls
,
self
.
_default_input_processor
)
def
process_input
(
self
,
model_config
:
"ModelConfig"
,
inputs
:
LLMInputs
)
->
LLMInputs
:
"""
...
...
@@ -243,15 +257,17 @@ class InputRegistry:
from
vllm.model_executor.model_loader
import
get_model_architecture
model_cls
,
_
=
get_model_architecture
(
model_config
)
processor
=
self
.
_get_model_input_processor
(
model_cls
)
processor
=
self
.
_input_processors_by_model_type
\
.
get
(
model_cls
,
self
.
_default_input
_processor
)
mm_
processor
_kwargs
=
get_allowed_kwarg_only_overrides
(
processor
,
overrides
=
model_config
.
mm
_processor
_kwargs
)
return
processor
(
InputContext
(
model_config
),
inputs
)
return
processor
(
InputContext
(
model_config
),
inputs
,
**
mm_processor_kwargs
)
def
create_input_processor
(
self
,
model_config
:
"ModelConfig"
):
"""
Create an input processor (see :meth:`process_input`) for a
Create an input processor (see :meth:`
_
process_input`) for a
specific model.
"""
return
functools
.
partial
(
self
.
process_input
,
model_config
)
vllm/multimodal/base.py
View file @
9b8c8ba1
...
...
@@ -14,7 +14,8 @@ from typing_extensions import TypeAlias
from
vllm.config
import
ModelConfig
from
vllm.inputs
import
InputContext
from
vllm.logger
import
init_logger
from
vllm.utils
import
JSONTree
,
is_list_of
,
json_map_leaves
from
vllm.utils
import
(
JSONTree
,
get_allowed_kwarg_only_overrides
,
is_list_of
,
json_map_leaves
)
logger
=
init_logger
(
__name__
)
...
...
@@ -256,11 +257,20 @@ class MultiModalPlugin(ABC):
model_cls
,
_
=
get_model_architecture
(
model_config
)
mapper
=
self
.
_input_mappers
.
get
(
model_cls
)
# Only get processor kwargs at mapping time if we are not using the default
# input mapper; no overrides are used on the default here because they
# should be passed to the huggingface resource at initialization time.
if
mapper
is
not
None
and
mapper
!=
self
.
_default_input_mapper
:
mm_processor_kwargs
=
get_allowed_kwarg_only_overrides
(
mapper
,
overrides
=
model_config
.
mm_processor_kwargs
)
else
:
mm_processor_kwargs
=
{}
if
mapper
is
None
:
raise
KeyError
(
f
"No input mapper in
{
self
}
is registered for "
f
"model class
{
model_cls
.
__name__
}
."
)
return
mapper
(
InputContext
(
model_config
),
data
)
return
mapper
(
InputContext
(
model_config
),
data
,
**
mm_processor_kwargs
)
@
abstractmethod
def
_default_max_multimodal_tokens
(
self
,
ctx
:
InputContext
)
->
int
:
...
...
@@ -333,7 +343,10 @@ class MultiModalPlugin(ABC):
f
"for model class
{
model_cls
.
__name__
}
in
{
self
}
."
)
if
callable
(
max_mm_tokens
):
max_mm_tokens
=
max_mm_tokens
(
InputContext
(
model_config
))
mm_processor_kwargs
=
get_allowed_kwarg_only_overrides
(
max_mm_tokens
,
overrides
=
model_config
.
mm_processor_kwargs
)
max_mm_tokens
=
max_mm_tokens
(
InputContext
(
model_config
),
**
mm_processor_kwargs
)
self
.
_validate_max_multimodal_tokens
(
max_mm_tokens
)
...
...
vllm/multimodal/image.py
View file @
9b8c8ba1
...
...
@@ -6,7 +6,7 @@ from PIL import Image
from
vllm.config
import
ModelConfig
from
vllm.inputs.registry
import
InputContext
from
vllm.logger
import
init_logger
from
vllm.transformers_utils.
image_
processor
import
get_image_processor
from
vllm.transformers_utils.processor
import
get_image_processor
from
vllm.utils
import
is_list_of
from
.base
import
MultiModalData
,
MultiModalInputs
,
MultiModalPlugin
...
...
@@ -23,9 +23,14 @@ class ImagePlugin(MultiModalPlugin):
return
"image"
def
_get_hf_image_processor
(
self
,
model_config
:
ModelConfig
):
mm_processor_kwargs
=
({}
if
model_config
.
mm_processor_kwargs
is
None
else
model_config
.
mm_processor_kwargs
)
# We don't explicitly check kwarg overrides to the HF class
# since the automodel just takes kwargs, so we can't inspect it
return
cached_get_image_processor
(
model_config
.
model
,
trust_remote_code
=
model_config
.
trust_remote_code
)
trust_remote_code
=
model_config
.
trust_remote_code
,
**
mm_processor_kwargs
)
def
_default_input_mapper
(
self
,
...
...
@@ -37,6 +42,7 @@ class ImagePlugin(MultiModalPlugin):
# PIL image
if
isinstance
(
data
,
Image
.
Image
)
or
is_list_of
(
data
,
Image
.
Image
):
image_processor
=
self
.
_get_hf_image_processor
(
model_config
)
if
image_processor
is
None
:
raise
RuntimeError
(
"No HuggingFace processor is available "
"to process the image object"
)
...
...
vllm/multimodal/registry.py
View file @
9b8c8ba1
...
...
@@ -138,6 +138,15 @@ class MultiModalRegistry:
"""
Create an input mapper (see :meth:`map_input`) for a specific model.
"""
# NOTE - we currently make the assumption that if a model has multiple
# supported modalities, they take the same kwargs. For the default,
# this could be an issue in the future if it falls back to two HF
# resources and we can't inspect the signature easily since it's
# getting initialized through the autoclass.
#
# If this is a problem in the future, we should revisit it, but since
# it potentially introduces a lot of complexity for a currently
# uncommon case, we do not for simplicity of both use & implementation
return
functools
.
partial
(
self
.
map_input
,
model_config
)
def
register_max_multimodal_tokens
(
...
...
vllm/multimodal/video.py
View file @
9b8c8ba1
...
...
@@ -6,7 +6,7 @@ import numpy as np
from
vllm.config
import
ModelConfig
from
vllm.inputs.registry
import
InputContext
from
vllm.logger
import
init_logger
from
vllm.transformers_utils.
image_
processor
import
get_video_processor
from
vllm.transformers_utils.processor
import
get_video_processor
from
vllm.transformers_utils.tokenizer
import
get_tokenizer
from
vllm.utils
import
is_list_of
...
...
@@ -37,9 +37,14 @@ class VideoPlugin(ImagePlugin):
return
"video"
def
_get_hf_video_processor
(
self
,
model_config
:
ModelConfig
):
mm_processor_kwargs
=
({}
if
model_config
.
mm_processor_kwargs
is
None
else
model_config
.
mm_processor_kwargs
)
# We don't explicitly check kwarg overrides to the HF class
# since the automodel just takes kwargs, so we can't inspect it
return
cached_get_video_processor
(
model_config
.
model
,
trust_remote_code
=
model_config
.
trust_remote_code
)
trust_remote_code
=
model_config
.
trust_remote_code
,
**
mm_processor_kwargs
)
def
_default_input_mapper
(
self
,
...
...
vllm/transformers_utils/image_processor.py
deleted
100644 → 0
View file @
d23679eb
from
typing
import
cast
def
get_video_processor
(
processor_name
:
str
,
trust_remote_code
:
bool
=
False
,
):
"""
Gets a processor for the given model name via HuggingFace.
"""
from
transformers
import
AutoProcessor
try
:
processor
=
AutoProcessor
.
from_pretrained
(
processor_name
)
video_processor
=
processor
.
video_processor
except
ValueError
as
e
:
if
not
trust_remote_code
:
err_msg
=
(
"Failed to load the processor. If the processor is "
"a custom processor not yet available in the HuggingFace "
"transformers library, consider setting "
"`trust_remote_code=True` in LLM or using the "
"`--trust-remote-code` flag in the CLI."
)
raise
RuntimeError
(
err_msg
)
from
e
else
:
raise
e
return
video_processor
def
get_image_processor
(
processor_name
:
str
,
*
args
,
trust_remote_code
:
bool
=
False
,
**
kwargs
,
):
"""Gets an image processor for the given model name via HuggingFace."""
# don't put this import at the top level
# it will call torch.cuda.device_count()
from
transformers
import
AutoImageProcessor
from
transformers.image_processing_utils
import
BaseImageProcessor
try
:
processor
=
AutoImageProcessor
.
from_pretrained
(
processor_name
,
*
args
,
trust_remote_code
=
trust_remote_code
,
**
kwargs
)
except
ValueError
as
e
:
# If the error pertains to the processor class not existing or not
# currently being imported, suggest using the --trust-remote-code flag.
# Unlike AutoTokenizer, AutoImageProcessor does not separate such errors
if
not
trust_remote_code
:
err_msg
=
(
"Failed to load the image processor. If the image processor is "
"a custom processor not yet available in the HuggingFace "
"transformers library, consider setting "
"`trust_remote_code=True` in LLM or using the "
"`--trust-remote-code` flag in the CLI."
)
raise
RuntimeError
(
err_msg
)
from
e
else
:
raise
e
return
cast
(
BaseImageProcessor
,
processor
)
vllm/transformers_utils/processor.py
View file @
9b8c8ba1
from
typing
import
cast
from
typing
import
Any
,
cast
def
get_processor
(
processor_name
:
str
,
*
args
,
*
args
:
Any
,
trust_remote_code
:
bool
=
False
,
**
kwargs
,
**
kwargs
:
Any
,
):
"""
Gets
a processor for the given model name via HuggingFace."""
"""
Load
a processor for the given model name via HuggingFace."""
# don't put this import at the top level
# it will call torch.cuda.device_count()
from
transformers
import
AutoProcessor
...
...
@@ -35,3 +35,60 @@ def get_processor(
raise
e
return
cast
(
ProcessorMixin
,
processor
)
def get_image_processor(
    processor_name: str,
    *args: Any,
    trust_remote_code: bool = False,
    **kwargs: Any,
):
    """Load an image processor for the given model name via HuggingFace.

    Args:
        processor_name: Model name forwarded to
            ``AutoImageProcessor.from_pretrained``.
        *args: Extra positional arguments forwarded to ``from_pretrained``.
        trust_remote_code: Forwarded to ``from_pretrained``; also controls
            how a ``ValueError`` raised while loading is reported.
        **kwargs: Extra keyword arguments forwarded to ``from_pretrained``.

    Raises:
        RuntimeError: If loading raises ``ValueError`` while
            ``trust_remote_code`` is False.
        ValueError: Re-raised as-is when ``trust_remote_code`` is True.
    """
    # don't put this import at the top level
    # it will call torch.cuda.device_count()
    from transformers import AutoImageProcessor
    from transformers.image_processing_utils import BaseImageProcessor

    try:
        loaded = AutoImageProcessor.from_pretrained(
            processor_name,
            *args,
            trust_remote_code=trust_remote_code,
            **kwargs)
    except ValueError as e:
        # If the error pertains to the processor class not existing or not
        # currently being imported, suggest using the --trust-remote-code flag.
        # Unlike AutoTokenizer, AutoImageProcessor does not separate such errors
        if trust_remote_code:
            raise e
        err_msg = (
            "Failed to load the image processor. If the image processor is "
            "a custom processor not yet available in the HuggingFace "
            "transformers library, consider setting "
            "`trust_remote_code=True` in LLM or using the "
            "`--trust-remote-code` flag in the CLI.")
        raise RuntimeError(err_msg) from e

    return cast(BaseImageProcessor, loaded)
def get_video_processor(
    processor_name: str,
    *args: Any,
    trust_remote_code: bool = False,
    **kwargs: Any,
):
    """Load a video processor for the given model name via HuggingFace.

    All arguments are forwarded to ``get_processor``; the result is the
    ``video_processor`` attribute of the processor it returns.
    """
    # don't put this import at the top level
    # it will call torch.cuda.device_count()
    from transformers.image_processing_utils import BaseImageProcessor

    hf_processor = get_processor(
        processor_name,
        *args,
        trust_remote_code=trust_remote_code,
        **kwargs,
    )
    video_processor = hf_processor.video_processor
    return cast(BaseImageProcessor, video_processor)
vllm/utils.py
View file @
9b8c8ba1
...
...
@@ -4,6 +4,7 @@ import contextlib
import
datetime
import
enum
import
gc
import
inspect
import
os
import
random
import
socket
...
...
@@ -1237,6 +1238,53 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args,
return
await
task
(
*
args
,
**
kwargs
)
def get_allowed_kwarg_only_overrides(
    callable: Callable[..., object],
    overrides: Optional[Dict[str, Any]],
) -> Dict[str, Any]:
    """
    Given a callable which has one or more keyword only params and a dict
    mapping param names to values, drop values that cannot be kwarg
    expanded to overwrite one or more keyword-only args. This is used in a
    few places to handle custom processor overrides for multimodal models,
    e.g., for profiling when processor options provided by the user
    may affect the number of mm tokens per instance.

    Args:
        callable: Callable which takes 0 or more keyword only arguments.
        overrides: Potential overrides to be used when invoking the callable.
            ``None`` or an empty dict yields an empty result.

    Returns:
        Dictionary containing the kwargs to be leveraged which may be used
        to overwrite one or more keyword only arguments when invoking the
        callable.
    """
    if not overrides:
        return {}

    # Only keyword-only parameters may be overridden; positional parameters
    # (e.g. the input context) are supplied by the caller itself.
    allowed_override_names = {
        name
        for name, param in inspect.signature(callable).parameters.items()
        if param.kind == inspect.Parameter.KEYWORD_ONLY
    }

    # Drop any mm_processor_kwargs provided by the user that are
    # not kwarg names accepted by the provided input processor.
    filtered_overrides = {
        kwarg_name: val
        for kwarg_name, val in overrides.items()
        if kwarg_name in allowed_override_names
    }

    # If anything is dropped, log a warning
    dropped_keys = overrides.keys() - filtered_overrides.keys()
    if dropped_keys:
        logger.warning(
            "The following intended overrides are not keyword-only args "
            "and will be dropped: %s", dropped_keys)

    return filtered_overrides
# Using dynamo with vLLM doesn't really work well with PyTorch versions < 2.4.0.
# In particular, the FakeScalarType is not supported for earlier versions of
# PyTorch which breaks dynamo for any ops registered using ScalarType.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment