Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a3691b6b
Unverified
Commit
a3691b6b
authored
Oct 08, 2024
by
Alex Brooks
Committed by
GitHub
Oct 08, 2024
Browse files
[Core][Frontend] Add Support for Inference Time mm_processor_kwargs (#9131)
Signed-off-by:
Alex-Brooks
<
Alex.Brooks@ibm.com
>
parent
8c746226
Changes
21
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
438 additions
and
120 deletions
+438
-120
examples/offline_inference_vision_language.py
examples/offline_inference_vision_language.py
+1
-0
tests/multimodal/test_processor_kwargs.py
tests/multimodal/test_processor_kwargs.py
+70
-40
tests/test_inputs.py
tests/test_inputs.py
+26
-0
tests/test_utils.py
tests/test_utils.py
+31
-1
vllm/core/scheduler.py
vllm/core/scheduler.py
+1
-0
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+7
-0
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+9
-0
vllm/inputs/data.py
vllm/inputs/data.py
+58
-9
vllm/inputs/preprocess.py
vllm/inputs/preprocess.py
+51
-19
vllm/inputs/registry.py
vllm/inputs/registry.py
+10
-3
vllm/multimodal/audio.py
vllm/multimodal/audio.py
+2
-2
vllm/multimodal/base.py
vllm/multimodal/base.py
+20
-11
vllm/multimodal/image.py
vllm/multimodal/image.py
+18
-6
vllm/multimodal/registry.py
vllm/multimodal/registry.py
+9
-4
vllm/multimodal/video.py
vllm/multimodal/video.py
+17
-7
vllm/sequence.py
vllm/sequence.py
+14
-0
vllm/utils.py
vllm/utils.py
+82
-13
vllm/worker/cpu_model_runner.py
vllm/worker/cpu_model_runner.py
+5
-3
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+3
-1
vllm/worker/neuron_model_runner.py
vllm/worker/neuron_model_runner.py
+4
-1
No files found.
examples/offline_inference_vision_language.py
View file @
a3691b6b
...
@@ -105,6 +105,7 @@ def run_phi3v(question: str, modality: str):
...
@@ -105,6 +105,7 @@ def run_phi3v(question: str, modality: str):
trust_remote_code
=
True
,
trust_remote_code
=
True
,
max_model_len
=
4096
,
max_model_len
=
4096
,
max_num_seqs
=
2
,
max_num_seqs
=
2
,
# Note - mm_processor_kwargs can also be passed to generate/chat calls
mm_processor_kwargs
=
{
"num_crops"
:
16
},
mm_processor_kwargs
=
{
"num_crops"
:
16
},
)
)
stop_token_ids
=
None
stop_token_ids
=
None
...
...
tests/multimodal/test_processor_kwargs.py
View file @
a3691b6b
...
@@ -74,11 +74,11 @@ def mm_model_cls():
...
@@ -74,11 +74,11 @@ def mm_model_cls():
# lambda whose signature matches max token calcs extra & mapper + extra kwargs
# lambda whose signature matches max token calcs extra & mapper + extra kwargs
get_num_crops
=
lambda
ctx
,
*
,
num_crops
=
DEFAULT_NUM_CROPS
:
num_crops
get_num_crops
=
lambda
ctx
,
*
,
num_crops
=
DEFAULT_NUM_CROPS
:
num_crops
custom_mapper
=
lambda
ctx
,
data
,
*
,
num_crops
=
DEFAULT_NUM_CROPS
:
{
custom_mapper
=
lambda
ctx
,
data
,
*
,
num_crops
=
DEFAULT_NUM_CROPS
:
{
"
num_
pixels"
:
torch
.
zeros
(
size
=
(
1
,
num_crops
+
1
,
3
,
336
,
336
))
"pixel
_value
s"
:
torch
.
zeros
(
size
=
(
1
,
num_crops
+
1
,
3
,
336
,
336
))
}
}
### Test for default processor logic & mm_processor_kwargs wrapping
### Test
s
for default processor logic & mm_processor_kwargs wrapping
def
test_default_processor_is_a_noop
():
def
test_default_processor_is_a_noop
():
"""Ensure that by default, there is no processor override."""
"""Ensure that by default, there is no processor override."""
dummy_registry
=
InputRegistry
()
dummy_registry
=
InputRegistry
()
...
@@ -89,23 +89,46 @@ def test_default_processor_is_a_noop():
...
@@ -89,23 +89,46 @@ def test_default_processor_is_a_noop():
assert
proc_inputs
is
proc_outputs
assert
proc_inputs
is
proc_outputs
@
pytest
.
mark
.
parametrize
(
"num_crops"
,
[
None
,
NUM_CROPS_OVERRIDE
])
def
_get_num_crops_info
(
init_num_crops
:
int
,
inference_num_crops
:
int
):
def
test_processor_default_kwargs
(
use_processor_mock
,
num_crops
):
"""Get the init / inference kwargs and expected num_crops for this test."""
"""Ensure input processors can use processor kwargs."""
dummy_registry
=
InputRegistry
()
# If we have a value for num_crops, pass the override value and make
# If we have a value for num_crops, pass the override value and make
# sure we get that value as a return-value from out mock processor,
# sure we get that value as a return-value from out mock processor,
# otherwise fall back to the default value
# otherwise fall back to the default value
mm_processor
_kwargs
=
None
if
num_crops
is
None
else
{
init
_kwargs
=
None
if
init_
num_crops
is
None
else
{
"num_crops"
:
num_crops
"num_crops"
:
init_
num_crops
}
}
expected_num_crops
=
DEFAULT_NUM_CROPS
if
num_crops
is
None
else
num_crops
inference_kwargs
=
None
if
inference_num_crops
is
None
else
{
ctx
=
build_model_context
(
DUMMY_MODEL_ID
,
"num_crops"
:
inference_num_crops
mm_processor_kwargs
=
mm_processor_kwargs
)
}
processor
=
dummy_registry
.
create_input_processor
(
ctx
.
model_config
)
if
inference_num_crops
is
not
None
:
expected_seq_count
=
inference_num_crops
elif
init_num_crops
is
not
None
:
expected_seq_count
=
init_num_crops
else
:
expected_seq_count
=
DEFAULT_NUM_CROPS
return
init_kwargs
,
inference_kwargs
,
expected_seq_count
@
pytest
.
mark
.
parametrize
(
"init_num_crops,inference_num_crops"
,
[
(
None
,
None
),
(
NUM_CROPS_OVERRIDE
,
None
),
(
DEFAULT_NUM_CROPS
,
NUM_CROPS_OVERRIDE
),
])
def
test_input_processor_kwargs
(
use_processor_mock
,
init_num_crops
,
inference_num_crops
):
"""Ensure input processors can use processor kwargs."""
dummy_registry
=
InputRegistry
()
init_kwargs
,
inference_kwargs
,
expected_seq_count
=
_get_num_crops_info
(
init_num_crops
,
inference_num_crops
)
num_crops_val
=
processor
(
LLMInputs
(
prompt_token_ids
=
[],
prompt
=
""
))
ctx
=
build_model_context
(
DUMMY_MODEL_ID
,
mm_processor_kwargs
=
init_kwargs
)
assert
num_crops_val
==
expected_num_crops
processor
=
dummy_registry
.
create_input_processor
(
ctx
.
model_config
)
num_crops_val
=
processor
(
LLMInputs
(
prompt_token_ids
=
[],
prompt
=
""
,
mm_processor_kwargs
=
inference_kwargs
))
assert
num_crops_val
==
expected_seq_count
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
...
@@ -124,11 +147,16 @@ def test_processor_with_sad_kwarg_overrides(use_processor_mock,
...
@@ -124,11 +147,16 @@ def test_processor_with_sad_kwarg_overrides(use_processor_mock,
mm_processor_kwargs
):
mm_processor_kwargs
):
"""Ensure that input processors filter out invalid mm_processor_kwargs"""
"""Ensure that input processors filter out invalid mm_processor_kwargs"""
dummy_registry
=
InputRegistry
()
dummy_registry
=
InputRegistry
()
# Should filter out the init time kwargs
ctx
=
build_model_context
(
DUMMY_MODEL_ID
,
ctx
=
build_model_context
(
DUMMY_MODEL_ID
,
mm_processor_kwargs
=
mm_processor_kwargs
)
mm_processor_kwargs
=
mm_processor_kwargs
)
processor
=
dummy_registry
.
create_input_processor
(
ctx
.
model_config
)
processor
=
dummy_registry
.
create_input_processor
(
ctx
.
model_config
)
num_crops_val
=
processor
(
LLMInputs
(
prompt_token_ids
=
[],
prompt
=
""
))
# Should filter out the inference time kwargs
num_crops_val
=
processor
(
LLMInputs
(
prompt_token_ids
=
[],
prompt
=
""
,
mm_processor_kwargs
=
mm_processor_kwargs
))
assert
num_crops_val
==
DEFAULT_NUM_CROPS
assert
num_crops_val
==
DEFAULT_NUM_CROPS
...
@@ -271,32 +299,34 @@ def test_default_mapper_with_processer_kwargs(image_assets, num_crops):
...
@@ -271,32 +299,34 @@ def test_default_mapper_with_processer_kwargs(image_assets, num_crops):
assert
mapped_inputs
[
"pixel_values"
].
shape
[
1
]
==
num_crops
+
1
assert
mapped_inputs
[
"pixel_values"
].
shape
[
1
]
==
num_crops
+
1
@
pytest
.
mark
.
parametrize
(
"num_crops"
,
[
None
,
NUM_CROPS_OVERRIDE
])
@
pytest
.
mark
.
parametrize
(
"init_num_crops,inference_num_crops"
,
[
def
test_custom_mapper_kwarg_overrides
(
image_assets
,
num_crops
):
(
None
,
None
),
(
NUM_CROPS_OVERRIDE
,
None
),
(
DEFAULT_NUM_CROPS
,
NUM_CROPS_OVERRIDE
),
])
def
test_custom_mapper_kwarg_overrides
(
image_assets
,
init_num_crops
,
inference_num_crops
):
"""Ensure custom mappers can use processor kwargs."""
"""Ensure custom mappers can use processor kwargs."""
mm_processor_kwargs
=
None
if
num_crops
is
None
else
{
init_kwargs
,
inference_kwargs
,
expected_seq_count
=
_get_num_crops_info
(
"num_crops"
:
num_crops
init_num_crops
,
inference_num_crops
)
}
expected_seq_count
=
DEFAULT_NUM_CROPS
if
num_crops
is
None
else
num_crops
ctx
=
build_model_context
(
MULTIMODAL_MODEL_ID
,
ctx
=
build_model_context
(
MULTIMODAL_MODEL_ID
,
trust_remote_code
=
True
,
trust_remote_code
=
True
,
mm_processor_kwargs
=
mm_processor
_kwargs
,
mm_processor_kwargs
=
init
_kwargs
,
limit_mm_per_prompt
=
{
"image"
:
1
})
limit_mm_per_prompt
=
{
"image"
:
1
})
mm_registry
=
MultiModalRegistry
()
mm_registry
=
MultiModalRegistry
()
mm_registry
.
init_mm_limits_per_prompt
(
ctx
.
model_config
)
mm_registry
.
init_mm_limits_per_prompt
(
ctx
.
model_config
)
# Patch the image registry for phi3v with our lambda that is compatible
# with overrides, then ensure that calling the method correctly echos
# our num_crops value back from the mm_processor_kwargs.
image
=
image_assets
[
0
].
pil_image
image
=
image_assets
[
0
].
pil_image
mm_inputs
=
{
"image"
:
image
}
mm_inputs
=
{
"image"
:
image
}
with
patch
.
object
(
# Patch the image registry for phi3v with our lambda that is compatible
mm_registry
.
_get_plugin
(
"image"
),
# with overrides, then ensure that calling the method correctly echos
"_default_input_mapper"
,
# our num_crops value back from the mm_processor_kwargs.
{
mm_model_cls
():
custom_mapper
},
mm_registry
.
_get_plugin
(
"image"
).
register_input_mapper
(
custom_mapper
)(
):
mm_model_cls
())
mapped_inputs
=
mm_registry
.
map_input
(
ctx
.
model_config
,
mm_inputs
)
mapped_inputs
=
mm_registry
.
map_input
(
ctx
.
model_config
,
mm_inputs
,
inference_kwargs
)
assert
mapped_inputs
[
"pixel_values"
].
shape
[
1
]
==
expected_seq_count
+
1
assert
mapped_inputs
[
"pixel_values"
].
shape
[
1
]
==
expected_seq_count
+
1
...
@@ -316,6 +346,7 @@ def test_custom_mapper_kwarg_overrides(image_assets, num_crops):
...
@@ -316,6 +346,7 @@ def test_custom_mapper_kwarg_overrides(image_assets, num_crops):
def
test_custom_mapper_with_sad_kwarg_overrides
(
image_assets
,
def
test_custom_mapper_with_sad_kwarg_overrides
(
image_assets
,
mm_processor_kwargs
):
mm_processor_kwargs
):
"""Ensure that custom mappers filters out invalid mm_processor_kwargs"""
"""Ensure that custom mappers filters out invalid mm_processor_kwargs"""
# Should filter out the init time kwargs
ctx
=
build_model_context
(
MULTIMODAL_MODEL_ID
,
ctx
=
build_model_context
(
MULTIMODAL_MODEL_ID
,
trust_remote_code
=
True
,
trust_remote_code
=
True
,
mm_processor_kwargs
=
mm_processor_kwargs
,
mm_processor_kwargs
=
mm_processor_kwargs
,
...
@@ -323,17 +354,16 @@ def test_custom_mapper_with_sad_kwarg_overrides(image_assets,
...
@@ -323,17 +354,16 @@ def test_custom_mapper_with_sad_kwarg_overrides(image_assets,
mm_registry
=
MultiModalRegistry
()
mm_registry
=
MultiModalRegistry
()
mm_registry
.
init_mm_limits_per_prompt
(
ctx
.
model_config
)
mm_registry
.
init_mm_limits_per_prompt
(
ctx
.
model_config
)
# Patch the image registry for phi3v with our lambda that is compatible
# with overrides, then ensure that calling the method correctly echos
# our num_crops value back from the mm_processor_kwargs.
image
=
image_assets
[
0
].
pil_image
image
=
image_assets
[
0
].
pil_image
mm_inputs
=
{
"image"
:
image
}
mm_inputs
=
{
"image"
:
image
}
with
patch
.
object
(
# Patch the image registry for phi3v with our lambda that is compatible
mm_registry
.
_get_plugin
(
"image"
),
# with overrides, then ensure that calling the method correctly echos
"_default_input_mapper"
,
# our num_crops value back from the mm_processor_kwargs.
{
mm_model_cls
():
custom_mapper
},
mm_registry
.
_get_plugin
(
"image"
).
register_input_mapper
(
custom_mapper
)(
):
mm_model_cls
())
mapped_inputs
=
mm_registry
.
map_input
(
ctx
.
model_config
,
mm_inputs
)
# Should filter out the inference time kwargs
mapped_inputs
=
mm_registry
.
map_input
(
ctx
.
model_config
,
mm_inputs
,
mm_processor_kwargs
=
mm_processor_kwargs
)
assert
mapped_inputs
[
"pixel_values"
].
shape
[
1
]
==
DEFAULT_NUM_CROPS
+
1
assert
mapped_inputs
[
"pixel_values"
].
shape
[
1
]
==
DEFAULT_NUM_CROPS
+
1
tests/test_inputs.py
View file @
a3691b6b
...
@@ -2,6 +2,7 @@ from typing import List
...
@@ -2,6 +2,7 @@ from typing import List
import
pytest
import
pytest
from
vllm.inputs
import
zip_enc_dec_prompts
from
vllm.inputs.parse
import
parse_and_batch_prompt
from
vllm.inputs.parse
import
parse_and_batch_prompt
STRING_INPUTS
=
[
STRING_INPUTS
=
[
...
@@ -51,3 +52,28 @@ def test_parse_single_batch_token_consistent(token_input: List[int]):
...
@@ -51,3 +52,28 @@ def test_parse_single_batch_token_consistent(token_input: List[int]):
def
test_parse_single_batch_string_slice
(
inputs_slice
:
slice
):
def
test_parse_single_batch_string_slice
(
inputs_slice
:
slice
):
assert
parse_and_batch_prompt
(
STRING_INPUTS
)[
inputs_slice
]
\
assert
parse_and_batch_prompt
(
STRING_INPUTS
)[
inputs_slice
]
\
==
parse_and_batch_prompt
(
STRING_INPUTS
[
inputs_slice
])
==
parse_and_batch_prompt
(
STRING_INPUTS
[
inputs_slice
])
# yapf: disable
@
pytest
.
mark
.
parametrize
(
'mm_processor_kwargs,expected_mm_kwargs'
,
[
(
None
,
[{},
{}]),
({},
[{},
{}]),
({
"foo"
:
100
},
[{
"foo"
:
100
},
{
"foo"
:
100
}]),
([{
"foo"
:
100
},
{
"bar"
:
200
}],
[{
"foo"
:
100
},
{
"bar"
:
200
}]),
])
# yapf: enable
def
test_zip_enc_dec_prompts
(
mm_processor_kwargs
,
expected_mm_kwargs
):
"""Test mm_processor_kwargs init for zipping enc/dec prompts."""
encoder_prompts
=
[
'An encoder prompt'
,
'Another encoder prompt'
]
decoder_prompts
=
[
'A decoder prompt'
,
'Another decoder prompt'
]
zipped_prompts
=
zip_enc_dec_prompts
(
encoder_prompts
,
decoder_prompts
,
mm_processor_kwargs
)
assert
len
(
zipped_prompts
)
==
len
(
encoder_prompts
)
==
len
(
decoder_prompts
)
for
enc
,
dec
,
exp_kwargs
,
zipped
in
zip
(
encoder_prompts
,
decoder_prompts
,
expected_mm_kwargs
,
zipped_prompts
):
assert
isinstance
(
zipped
,
dict
)
assert
len
(
zipped
.
keys
())
==
3
assert
zipped
[
'encoder_prompt'
]
==
enc
assert
zipped
[
'decoder_prompt'
]
==
dec
assert
zipped
[
'mm_processor_kwargs'
]
==
exp_kwargs
tests/test_utils.py
View file @
a3691b6b
...
@@ -7,7 +7,7 @@ from typing import AsyncIterator, Tuple
...
@@ -7,7 +7,7 @@ from typing import AsyncIterator, Tuple
import
pytest
import
pytest
from
vllm.utils
import
(
FlexibleArgumentParser
,
deprecate_kwargs
,
from
vllm.utils
import
(
FlexibleArgumentParser
,
deprecate_kwargs
,
get_open_port
,
merge_async_iterators
)
get_open_port
,
merge_async_iterators
,
supports_kw
)
from
.utils
import
error_on_warning
from
.utils
import
error_on_warning
...
@@ -236,3 +236,33 @@ def test_no_model_tag(parser_with_config):
...
@@ -236,3 +236,33 @@ def test_no_model_tag(parser_with_config):
with
pytest
.
raises
(
ValueError
):
with
pytest
.
raises
(
ValueError
):
parser_with_config
.
parse_args
(
parser_with_config
.
parse_args
(
[
'serve'
,
'--config'
,
'./data/test_config.yaml'
])
[
'serve'
,
'--config'
,
'./data/test_config.yaml'
])
# yapf: enable
@
pytest
.
mark
.
parametrize
(
"callable,kw_name,requires_kw_only,allow_var_kwargs,is_supported"
,
[
# Tests for positional argument support
(
lambda
foo
:
None
,
"foo"
,
True
,
True
,
False
),
(
lambda
foo
:
None
,
"foo"
,
False
,
True
,
True
),
# Tests for positional or keyword / keyword only
(
lambda
foo
=
100
:
None
,
"foo"
,
True
,
True
,
False
),
(
lambda
*
,
foo
:
None
,
"foo"
,
False
,
True
,
True
),
# Tests to make sure the names of variadic params are NOT supported
(
lambda
*
args
:
None
,
"args"
,
False
,
True
,
False
),
(
lambda
**
kwargs
:
None
,
"kwargs"
,
False
,
True
,
False
),
# Tests for if we allow var kwargs to add support
(
lambda
foo
:
None
,
"something_else"
,
False
,
True
,
False
),
(
lambda
foo
,
**
kwargs
:
None
,
"something_else"
,
False
,
True
,
True
),
(
lambda
foo
,
**
kwargs
:
None
,
"kwargs"
,
True
,
True
,
False
),
(
lambda
foo
,
**
kwargs
:
None
,
"foo"
,
True
,
True
,
False
),
])
# yapf: disable
def
test_supports_kw
(
callable
,
kw_name
,
requires_kw_only
,
allow_var_kwargs
,
is_supported
):
assert
supports_kw
(
callable
=
callable
,
kw_name
=
kw_name
,
requires_kw_only
=
requires_kw_only
,
allow_var_kwargs
=
allow_var_kwargs
)
==
is_supported
vllm/core/scheduler.py
View file @
a3691b6b
...
@@ -1309,6 +1309,7 @@ class Scheduler:
...
@@ -1309,6 +1309,7 @@ class Scheduler:
# `multi_modal_data` will be None.
# `multi_modal_data` will be None.
multi_modal_data
=
seq_group
.
multi_modal_data
multi_modal_data
=
seq_group
.
multi_modal_data
if
scheduler_outputs
.
num_prefill_groups
>
0
else
None
,
if
scheduler_outputs
.
num_prefill_groups
>
0
else
None
,
mm_processor_kwargs
=
seq_group
.
mm_processor_kwargs
,
prompt_adapter_request
=
seq_group
.
prompt_adapter_request
,
prompt_adapter_request
=
seq_group
.
prompt_adapter_request
,
)
)
else
:
else
:
...
...
vllm/engine/llm_engine.py
View file @
a3691b6b
...
@@ -811,6 +811,13 @@ class LLMEngine:
...
@@ -811,6 +811,13 @@ class LLMEngine:
)
)
processed_inputs
=
self
.
input_processor
(
preprocessed_inputs
)
processed_inputs
=
self
.
input_processor
(
preprocessed_inputs
)
# This is a bit of a hack - copy the mm_processor_kwargs that were
# used in the input processor to the processed output, since these
# kwargs are presumed to be immutable and the values should be aligned
# between the input processor (here) and the input mapper.
processed_inputs
[
"mm_processor_kwargs"
]
=
preprocessed_inputs
.
get
(
"mm_processor_kwargs"
)
self
.
_add_processed_request
(
self
.
_add_processed_request
(
request_id
=
request_id
,
request_id
=
request_id
,
processed_inputs
=
processed_inputs
,
processed_inputs
=
processed_inputs
,
...
...
vllm/entrypoints/llm.py
View file @
a3691b6b
...
@@ -472,6 +472,7 @@ class LLM:
...
@@ -472,6 +472,7 @@ class LLM:
add_generation_prompt
:
bool
=
True
,
add_generation_prompt
:
bool
=
True
,
continue_final_message
:
bool
=
False
,
continue_final_message
:
bool
=
False
,
tools
:
Optional
[
List
[
Dict
[
str
,
Any
]]]
=
None
,
tools
:
Optional
[
List
[
Dict
[
str
,
Any
]]]
=
None
,
mm_processor_kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
)
->
List
[
RequestOutput
]:
)
->
List
[
RequestOutput
]:
"""
"""
Generate responses for a chat conversation.
Generate responses for a chat conversation.
...
@@ -501,6 +502,8 @@ class LLM:
...
@@ -501,6 +502,8 @@ class LLM:
continue_final_message: If True, continues the final message in
continue_final_message: If True, continues the final message in
the conversation instead of starting a new one. Cannot be `True`
the conversation instead of starting a new one. Cannot be `True`
if `add_generation_prompt` is also `True`.
if `add_generation_prompt` is also `True`.
mm_processor_kwargs: Multimodal processor kwarg overrides for this
chat request. Only used for offline requests.
Returns:
Returns:
A list of ``RequestOutput`` objects containing the generated
A list of ``RequestOutput`` objects containing the generated
...
@@ -522,6 +525,9 @@ class LLM:
...
@@ -522,6 +525,9 @@ class LLM:
tokenizer
=
self
.
get_tokenizer
()
tokenizer
=
self
.
get_tokenizer
()
model_config
=
self
.
llm_engine
.
get_model_config
()
model_config
=
self
.
llm_engine
.
get_model_config
()
# NOTE: _parse_chat_message_content_parts() currently doesn't
# handle mm_processor_kwargs, since there is no implementation in
# the chat message parsing for it.
conversation
,
mm_data
=
parse_chat_messages
(
conversation
,
mm_data
=
parse_chat_messages
(
msgs
,
model_config
,
tokenizer
)
msgs
,
model_config
,
tokenizer
)
...
@@ -554,6 +560,9 @@ class LLM:
...
@@ -554,6 +560,9 @@ class LLM:
if
mm_data
is
not
None
:
if
mm_data
is
not
None
:
prompt
[
"multi_modal_data"
]
=
mm_data
prompt
[
"multi_modal_data"
]
=
mm_data
if
mm_processor_kwargs
is
not
None
:
prompt
[
"mm_processor_kwargs"
]
=
mm_processor_kwargs
prompts
.
append
(
prompt
)
prompts
.
append
(
prompt
)
return
self
.
generate
(
return
self
.
generate
(
...
...
vllm/inputs/data.py
View file @
a3691b6b
from
typing
import
(
TYPE_CHECKING
,
Generic
,
Iterable
,
List
,
Optional
,
Tuple
,
from
typing
import
(
TYPE_CHECKING
,
Any
,
Dict
,
Generic
,
Iterable
,
List
,
Union
)
Optional
,
Tuple
,
Union
)
from
typing_extensions
import
NotRequired
,
TypedDict
,
TypeVar
from
typing_extensions
import
NotRequired
,
TypedDict
,
TypeVar
...
@@ -19,6 +19,14 @@ class TextPrompt(TypedDict):
...
@@ -19,6 +19,14 @@ class TextPrompt(TypedDict):
if the model supports it.
if the model supports it.
"""
"""
mm_processor_kwargs
:
NotRequired
[
Dict
[
str
,
Any
]]
"""
Optional multi-modal processor kwargs to be forwarded to the
multimodal input mapper & processor. Note that if multiple modalities
have registered mappers etc for the model being considered, we attempt
to pass the mm_processor_kwargs to each of them.
"""
class
TokensPrompt
(
TypedDict
):
class
TokensPrompt
(
TypedDict
):
"""Schema for a tokenized prompt."""
"""Schema for a tokenized prompt."""
...
@@ -32,6 +40,14 @@ class TokensPrompt(TypedDict):
...
@@ -32,6 +40,14 @@ class TokensPrompt(TypedDict):
if the model supports it.
if the model supports it.
"""
"""
mm_processor_kwargs
:
NotRequired
[
Dict
[
str
,
Any
]]
"""
Optional multi-modal processor kwargs to be forwarded to the
multimodal input mapper & processor. Note that if multiple modalities
have registered mappers etc for the model being considered, we attempt
to pass the mm_processor_kwargs to each of them.
"""
SingletonPrompt
=
Union
[
str
,
TextPrompt
,
TokensPrompt
]
SingletonPrompt
=
Union
[
str
,
TextPrompt
,
TokensPrompt
]
"""
"""
...
@@ -74,7 +90,9 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
...
@@ -74,7 +90,9 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
according to any of the :class:`SingletonPrompt` schemas,
according to any of the :class:`SingletonPrompt` schemas,
and are not required to have the same schema.
and are not required to have the same schema.
Only the encoder prompt may have multi-modal data.
Only the encoder prompt may have multi-modal data. mm_processor_kwargs
should be at the top-level, and should not be set in the encoder/decoder
prompts, since they are agnostic to the encoder/decoder.
Note that an :class:`ExplicitEncoderDecoderPrompt` may not
Note that an :class:`ExplicitEncoderDecoderPrompt` may not
be used as an input to a decoder-only model,
be used as an input to a decoder-only model,
...
@@ -87,6 +105,8 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
...
@@ -87,6 +105,8 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
decoder_prompt
:
Optional
[
_T2_co
]
decoder_prompt
:
Optional
[
_T2_co
]
mm_processor_kwargs
:
NotRequired
[
Dict
[
str
,
Any
]]
PromptType
=
Union
[
SingletonPrompt
,
ExplicitEncoderDecoderPrompt
]
PromptType
=
Union
[
SingletonPrompt
,
ExplicitEncoderDecoderPrompt
]
"""
"""
...
@@ -121,6 +141,14 @@ class LLMInputs(TypedDict):
...
@@ -121,6 +141,14 @@ class LLMInputs(TypedDict):
if the model supports it.
if the model supports it.
"""
"""
mm_processor_kwargs
:
NotRequired
[
Optional
[
Dict
[
str
,
Any
]]]
"""
Optional multi-modal processor kwargs to be forwarded to the
multimodal input mapper & processor. Note that if multiple modalities
have registered mappers etc for the model being considered, we attempt
to pass the mm_processor_kwargs to each of them.
"""
class
EncoderDecoderLLMInputs
(
LLMInputs
):
class
EncoderDecoderLLMInputs
(
LLMInputs
):
"""
"""
...
@@ -152,22 +180,43 @@ _T2 = TypeVar("_T2", bound=SingletonPrompt, default=SingletonPrompt)
...
@@ -152,22 +180,43 @@ _T2 = TypeVar("_T2", bound=SingletonPrompt, default=SingletonPrompt)
def
build_explicit_enc_dec_prompt
(
def
build_explicit_enc_dec_prompt
(
encoder_prompt
:
_T1
,
encoder_prompt
:
_T1
,
decoder_prompt
:
Optional
[
_T2
],
decoder_prompt
:
Optional
[
_T2
],
mm_processor_kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
)
->
ExplicitEncoderDecoderPrompt
[
_T1
,
_T2
]:
)
->
ExplicitEncoderDecoderPrompt
[
_T1
,
_T2
]:
return
ExplicitEncoderDecoderPrompt
(
encoder_prompt
=
encoder_prompt
,
if
mm_processor_kwargs
is
None
:
decoder_prompt
=
decoder_prompt
)
mm_processor_kwargs
=
{}
return
ExplicitEncoderDecoderPrompt
(
encoder_prompt
=
encoder_prompt
,
decoder_prompt
=
decoder_prompt
,
mm_processor_kwargs
=
mm_processor_kwargs
)
def
zip_enc_dec_prompts
(
def
zip_enc_dec_prompts
(
enc_prompts
:
Iterable
[
_T1
],
enc_prompts
:
Iterable
[
_T1
],
dec_prompts
:
Iterable
[
Optional
[
_T2
]],
dec_prompts
:
Iterable
[
Optional
[
_T2
]],
mm_processor_kwargs
:
Optional
[
Union
[
Iterable
[
Dict
[
str
,
Any
]],
Dict
[
str
,
Any
]]]
=
None
,
)
->
List
[
ExplicitEncoderDecoderPrompt
[
_T1
,
_T2
]]:
)
->
List
[
ExplicitEncoderDecoderPrompt
[
_T1
,
_T2
]]:
"""
"""
Zip encoder and decoder prompts together into a list of
Zip encoder and decoder prompts together into a list of
:class:`ExplicitEncoderDecoderPrompt` instances.
:class:`ExplicitEncoderDecoderPrompt` instances. mm_processor_kwargs
"""
may also be provided; if a dict is passed, the same dictionary will be
used for every encoder/decoder prompt. If an iterable is provided, it will
be zipped with the encoder/decoder prompts.
"""
if
mm_processor_kwargs
is
None
:
mm_processor_kwargs
=
{}
if
isinstance
(
mm_processor_kwargs
,
Dict
):
return
[
build_explicit_enc_dec_prompt
(
encoder_prompt
,
decoder_prompt
,
mm_processor_kwargs
)
for
(
encoder_prompt
,
decoder_prompt
)
in
zip
(
enc_prompts
,
dec_prompts
)
]
return
[
return
[
build_explicit_enc_dec_prompt
(
encoder_prompt
,
decoder_prompt
)
build_explicit_enc_dec_prompt
(
encoder_prompt
,
decoder_prompt
,
for
(
encoder_prompt
,
decoder_prompt
)
in
zip
(
enc_prompts
,
dec_prompts
)
mm_proc_kwargs
)
for
(
encoder_prompt
,
decoder_prompt
,
mm_proc_kwargs
)
in
zip
(
enc_prompts
,
dec_prompts
,
mm_processor_kwargs
)
]
]
...
...
vllm/inputs/preprocess.py
View file @
a3691b6b
import
asyncio
import
asyncio
from
typing
import
TYPE_CHECKING
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
typing_extensions
import
assert_never
from
typing_extensions
import
assert_never
...
@@ -20,9 +20,11 @@ if TYPE_CHECKING:
...
@@ -20,9 +20,11 @@ if TYPE_CHECKING:
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
PromptComponents
=
Tuple
[
Optional
[
str
],
List
[
int
],
PromptComponents
=
Tuple
[
Optional
[
str
],
List
[
int
],
Optional
[
"MultiModalDataDict"
]]
Optional
[
"MultiModalDataDict"
],
Optional
[
Dict
[
str
,
Any
]]]
DecoderPromptComponents
=
Tuple
[
Optional
[
str
],
Optional
[
List
[
int
]],
DecoderPromptComponents
=
Tuple
[
Optional
[
str
],
Optional
[
List
[
int
]],
Optional
[
"MultiModalDataDict"
]]
Optional
[
"MultiModalDataDict"
],
Optional
[
Dict
[
str
,
Any
]]]
class
InputPreprocessor
:
class
InputPreprocessor
:
...
@@ -227,6 +229,7 @@ class InputPreprocessor:
...
@@ -227,6 +229,7 @@ class InputPreprocessor:
* prompt
* prompt
* prompt_token_ids
* prompt_token_ids
* multi_modal_data
* multi_modal_data
* mm_processor_kwargs (request-level input processor/mapper overrides)
'''
'''
parsed
=
parse_singleton_prompt
(
prompt
)
parsed
=
parse_singleton_prompt
(
prompt
)
...
@@ -239,10 +242,12 @@ class InputPreprocessor:
...
@@ -239,10 +242,12 @@ class InputPreprocessor:
lora_request
=
lora_request
,
lora_request
=
lora_request
,
)
)
multi_modal_data
=
None
multi_modal_data
=
None
mm_processor_kwargs
=
None
elif
parsed
[
"type"
]
==
"tokens"
:
elif
parsed
[
"type"
]
==
"tokens"
:
prompt_text
=
None
prompt_text
=
None
prompt_token_ids
=
parsed
[
"content"
][
"prompt_token_ids"
]
prompt_token_ids
=
parsed
[
"content"
][
"prompt_token_ids"
]
multi_modal_data
=
parsed
[
"content"
].
get
(
"multi_modal_data"
)
multi_modal_data
=
parsed
[
"content"
].
get
(
"multi_modal_data"
)
mm_processor_kwargs
=
parsed
[
"content"
].
get
(
"mm_processor_kwargs"
)
elif
parsed
[
"type"
]
==
"text"
:
elif
parsed
[
"type"
]
==
"text"
:
prompt_text
=
parsed
[
"content"
][
"prompt"
]
prompt_text
=
parsed
[
"content"
][
"prompt"
]
prompt_token_ids
=
self
.
_tokenize_prompt
(
prompt_token_ids
=
self
.
_tokenize_prompt
(
...
@@ -251,10 +256,12 @@ class InputPreprocessor:
...
@@ -251,10 +256,12 @@ class InputPreprocessor:
lora_request
=
lora_request
,
lora_request
=
lora_request
,
)
)
multi_modal_data
=
parsed
[
"content"
].
get
(
"multi_modal_data"
)
multi_modal_data
=
parsed
[
"content"
].
get
(
"multi_modal_data"
)
mm_processor_kwargs
=
parsed
[
"content"
].
get
(
"mm_processor_kwargs"
)
else
:
else
:
assert_never
(
parsed
)
assert_never
(
parsed
)
return
prompt_text
,
prompt_token_ids
,
multi_modal_data
return
(
prompt_text
,
prompt_token_ids
,
multi_modal_data
,
mm_processor_kwargs
)
async
def
_extract_prompt_components_async
(
async
def
_extract_prompt_components_async
(
self
,
self
,
...
@@ -273,10 +280,12 @@ class InputPreprocessor:
...
@@ -273,10 +280,12 @@ class InputPreprocessor:
lora_request
=
lora_request
,
lora_request
=
lora_request
,
)
)
multi_modal_data
=
None
multi_modal_data
=
None
mm_processor_kwargs
=
None
elif
parsed
[
"type"
]
==
"tokens"
:
elif
parsed
[
"type"
]
==
"tokens"
:
prompt_text
=
None
prompt_text
=
None
prompt_token_ids
=
parsed
[
"content"
][
"prompt_token_ids"
]
prompt_token_ids
=
parsed
[
"content"
][
"prompt_token_ids"
]
multi_modal_data
=
parsed
[
"content"
].
get
(
"multi_modal_data"
)
multi_modal_data
=
parsed
[
"content"
].
get
(
"multi_modal_data"
)
mm_processor_kwargs
=
parsed
[
"content"
].
get
(
"mm_processor_kwargs"
)
elif
parsed
[
"type"
]
==
"text"
:
elif
parsed
[
"type"
]
==
"text"
:
prompt_text
=
parsed
[
"content"
][
"prompt"
]
prompt_text
=
parsed
[
"content"
][
"prompt"
]
prompt_token_ids
=
await
self
.
_tokenize_prompt_async
(
prompt_token_ids
=
await
self
.
_tokenize_prompt_async
(
...
@@ -285,18 +294,21 @@ class InputPreprocessor:
...
@@ -285,18 +294,21 @@ class InputPreprocessor:
lora_request
=
lora_request
,
lora_request
=
lora_request
,
)
)
multi_modal_data
=
parsed
[
"content"
].
get
(
"multi_modal_data"
)
multi_modal_data
=
parsed
[
"content"
].
get
(
"multi_modal_data"
)
mm_processor_kwargs
=
parsed
[
"content"
].
get
(
"mm_processor_kwargs"
)
else
:
else
:
assert_never
(
parsed
)
assert_never
(
parsed
)
return
prompt_text
,
prompt_token_ids
,
multi_modal_data
return
(
prompt_text
,
prompt_token_ids
,
multi_modal_data
,
mm_processor_kwargs
)
def
_build_enc_dec_llm_inputs
(
def
_build_enc_dec_llm_inputs
(
self
,
self
,
encoder_comps
:
PromptComponents
,
encoder_comps
:
PromptComponents
,
decoder_comps
:
DecoderPromptComponents
,
decoder_comps
:
DecoderPromptComponents
,
mm_processor_kwargs
:
Dict
[
str
,
Any
],
)
->
EncoderDecoderLLMInputs
:
)
->
EncoderDecoderLLMInputs
:
encoder_prompt
,
encoder_prompt_ids
,
encoder_mm_data
=
encoder_comps
encoder_prompt
,
encoder_prompt_ids
,
encoder_mm_data
,
_
=
encoder_comps
decoder_prompt
,
decoder_prompt_ids
,
decoder_mm_data
=
decoder_comps
decoder_prompt
,
decoder_prompt_ids
,
decoder_mm_data
,
_
=
decoder_comps
if
decoder_mm_data
is
not
None
:
if
decoder_mm_data
is
not
None
:
raise
ValueError
(
raise
ValueError
(
...
@@ -314,6 +326,7 @@ class InputPreprocessor:
...
@@ -314,6 +326,7 @@ class InputPreprocessor:
prompt_token_ids
=
decoder_prompt_ids
,
prompt_token_ids
=
decoder_prompt_ids
,
prompt
=
decoder_prompt
,
prompt
=
decoder_prompt
,
multi_modal_data
=
decoder_mm_data
,
multi_modal_data
=
decoder_mm_data
,
mm_processor_kwargs
=
mm_processor_kwargs
,
encoder_prompt_token_ids
=
encoder_prompt_ids
,
encoder_prompt_token_ids
=
encoder_prompt_ids
,
encoder_prompt
=
encoder_prompt
,
encoder_prompt
=
encoder_prompt
,
encoder_multi_modal_data
=
encoder_mm_data
,
encoder_multi_modal_data
=
encoder_mm_data
,
...
@@ -367,21 +380,30 @@ class InputPreprocessor:
...
@@ -367,21 +380,30 @@ class InputPreprocessor:
)
)
if
(
decoder_input
:
=
prompt
[
"decoder_prompt"
])
is
None
:
if
(
decoder_input
:
=
prompt
[
"decoder_prompt"
])
is
None
:
decoder_comps
=
None
,
None
,
None
decoder_comps
=
None
,
None
,
None
,
None
else
:
else
:
decoder_comps
=
self
.
_extract_prompt_components
(
decoder_comps
=
self
.
_extract_prompt_components
(
decoder_input
,
decoder_input
,
request_id
=
request_id
,
request_id
=
request_id
,
)
)
# Handle this carefully in case it was directly initialized by user
mm_processor_kwargs
=
prompt
.
get
(
"mm_processor_kwargs"
,
{})
else
:
else
:
encoder_comps
=
self
.
_extract_prompt_components
(
encoder_comps
=
self
.
_extract_prompt_components
(
prompt
,
prompt
,
request_id
=
request_id
,
request_id
=
request_id
,
)
)
# If there are no decoder components, we assume the
decoder_comps
=
None
,
None
,
None
# mm_processor_kwargs are in the encoder prompt
mm_processor_kwargs
=
encoder_comps
[
-
1
]
if
encoder_comps
[
return
self
.
_build_enc_dec_llm_inputs
(
encoder_comps
,
decoder_comps
)
-
1
]
is
not
None
else
{}
decoder_comps
=
None
,
None
,
None
,
None
return
self
.
_build_enc_dec_llm_inputs
(
encoder_comps
,
decoder_comps
,
mm_processor_kwargs
,
)
async
def
_process_encoder_decoder_prompt_async
(
async
def
_process_encoder_decoder_prompt_async
(
self
,
self
,
...
@@ -400,7 +422,7 @@ class InputPreprocessor:
...
@@ -400,7 +422,7 @@ class InputPreprocessor:
if
(
decoder_input
:
=
prompt
[
"decoder_prompt"
])
is
None
:
if
(
decoder_input
:
=
prompt
[
"decoder_prompt"
])
is
None
:
encoder_comps
=
await
encoder_task
encoder_comps
=
await
encoder_task
decoder_comps
=
None
,
None
,
None
decoder_comps
=
None
,
None
,
None
,
None
else
:
else
:
decoder_task
=
self
.
_extract_prompt_components_async
(
decoder_task
=
self
.
_extract_prompt_components_async
(
decoder_input
,
decoder_input
,
...
@@ -409,29 +431,39 @@ class InputPreprocessor:
...
@@ -409,29 +431,39 @@ class InputPreprocessor:
encoder_comps
,
decoder_comps
=
await
asyncio
.
gather
(
encoder_comps
,
decoder_comps
=
await
asyncio
.
gather
(
encoder_task
,
decoder_task
)
encoder_task
,
decoder_task
)
mm_processor_kwargs
=
prompt
[
"mm_processor_kwargs"
]
else
:
else
:
encoder_comps
=
await
self
.
_extract_prompt_components_async
(
encoder_comps
=
await
self
.
_extract_prompt_components_async
(
prompt
,
prompt
,
request_id
=
request_id
,
request_id
=
request_id
,
)
)
# If there are no decoder components, we assume the
decoder_comps
=
None
,
None
,
None
# mm_processor_kwargs are in the encoder prompt
mm_processor_kwargs
=
encoder_comps
[
-
1
]
if
encoder_comps
[
return
self
.
_build_enc_dec_llm_inputs
(
encoder_comps
,
decoder_comps
)
-
1
]
is
not
None
else
{}
decoder_comps
=
None
,
None
,
None
,
None
return
self
.
_build_enc_dec_llm_inputs
(
encoder_comps
,
decoder_comps
,
mm_processor_kwargs
,
)
def
_build_decoder_only_llm_inputs
(
def
_build_decoder_only_llm_inputs
(
self
,
self
,
prompt_comps
:
PromptComponents
,
prompt_comps
:
PromptComponents
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
],
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
],
)
->
LLMInputs
:
)
->
LLMInputs
:
prompt
,
prompt_token_ids
,
multi_modal_data
=
prompt_comps
(
prompt
,
prompt_token_ids
,
multi_modal_data
,
mm_processor_kwargs
)
=
prompt_comps
prompt_token_ids
=
self
.
_apply_prompt_adapter
(
prompt_token_ids
=
self
.
_apply_prompt_adapter
(
prompt_token_ids
,
prompt_adapter_request
=
prompt_adapter_request
)
prompt_token_ids
,
prompt_adapter_request
=
prompt_adapter_request
)
return
LLMInputs
(
prompt_token_ids
=
prompt_token_ids
,
return
LLMInputs
(
prompt_token_ids
=
prompt_token_ids
,
prompt
=
prompt
,
prompt
=
prompt
,
multi_modal_data
=
multi_modal_data
)
multi_modal_data
=
multi_modal_data
,
mm_processor_kwargs
=
mm_processor_kwargs
)
def
_process_decoder_only_prompt
(
def
_process_decoder_only_prompt
(
self
,
self
,
...
...
vllm/inputs/registry.py
View file @
a3691b6b
...
@@ -9,7 +9,8 @@ from transformers import PretrainedConfig
...
@@ -9,7 +9,8 @@ from transformers import PretrainedConfig
from
typing_extensions
import
TypeVar
from
typing_extensions
import
TypeVar
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
get_allowed_kwarg_only_overrides
,
print_warning_once
from
vllm.utils
import
(
get_allowed_kwarg_only_overrides
,
print_warning_once
,
resolve_mm_processor_kwargs
)
from
.data
import
LLMInputs
from
.data
import
LLMInputs
...
@@ -293,8 +294,14 @@ class InputRegistry:
...
@@ -293,8 +294,14 @@ class InputRegistry:
model_cls
,
_
=
get_model_architecture
(
model_config
)
model_cls
,
_
=
get_model_architecture
(
model_config
)
processor
=
self
.
_get_model_input_processor
(
model_cls
)
processor
=
self
.
_get_model_input_processor
(
model_cls
)
mm_processor_kwargs
=
get_allowed_kwarg_only_overrides
(
# Handle multimodal processor kwargs with priority:
processor
,
overrides
=
model_config
.
mm_processor_kwargs
)
# Inference kwargs -> Init kwargs -> {}
# If it's empty, it'll fall back to the default kwarg values
mm_processor_kwargs
=
resolve_mm_processor_kwargs
(
model_config
.
mm_processor_kwargs
,
inputs
.
get
(
"mm_processor_kwargs"
),
processor
,
)
return
processor
(
InputContext
(
model_config
),
inputs
,
return
processor
(
InputContext
(
model_config
),
inputs
,
**
mm_processor_kwargs
)
**
mm_processor_kwargs
)
...
...
vllm/multimodal/audio.py
View file @
a3691b6b
...
@@ -8,8 +8,8 @@ class AudioPlugin(MultiModalPlugin):
...
@@ -8,8 +8,8 @@ class AudioPlugin(MultiModalPlugin):
def
get_data_key
(
self
)
->
str
:
def
get_data_key
(
self
)
->
str
:
return
"audio"
return
"audio"
def
_default_input_mapper
(
self
,
ctx
:
InputContext
,
def
_default_input_mapper
(
self
,
ctx
:
InputContext
,
data
:
object
,
data
:
object
)
->
MultiModalInputs
:
**
mm_processor_kwargs
)
->
MultiModalInputs
:
raise
NotImplementedError
(
"There is no default audio input mapper"
)
raise
NotImplementedError
(
"There is no default audio input mapper"
)
def
_default_max_multimodal_tokens
(
self
,
ctx
:
InputContext
)
->
int
:
def
_default_max_multimodal_tokens
(
self
,
ctx
:
InputContext
)
->
int
:
...
...
vllm/multimodal/base.py
View file @
a3691b6b
import
sys
import
sys
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
collections
import
UserDict
,
defaultdict
from
collections
import
UserDict
,
defaultdict
from
typing
import
(
Callable
,
Dict
,
List
,
Mapping
,
Optional
,
Tuple
,
Type
,
from
typing
import
(
Any
,
Callable
,
Dict
,
List
,
Mapping
,
Optional
,
Tuple
,
Type
,
TypedDict
,
TypeVar
,
Union
,
cast
,
final
)
TypedDict
,
TypeVar
,
Union
,
cast
,
final
)
import
numpy
as
np
import
numpy
as
np
...
@@ -15,7 +15,7 @@ from vllm.config import ModelConfig
...
@@ -15,7 +15,7 @@ from vllm.config import ModelConfig
from
vllm.inputs
import
InputContext
from
vllm.inputs
import
InputContext
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
(
JSONTree
,
get_allowed_kwarg_only_overrides
,
is_list_of
,
from
vllm.utils
import
(
JSONTree
,
get_allowed_kwarg_only_overrides
,
is_list_of
,
json_map_leaves
)
json_map_leaves
,
resolve_mm_processor_kwargs
)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -200,6 +200,7 @@ class MultiModalPlugin(ABC):
...
@@ -200,6 +200,7 @@ class MultiModalPlugin(ABC):
self
,
self
,
ctx
:
InputContext
,
ctx
:
InputContext
,
data
:
MultiModalData
[
object
],
data
:
MultiModalData
[
object
],
**
mm_processor_kwargs
,
)
->
MultiModalInputs
:
)
->
MultiModalInputs
:
"""
"""
Return a dictionary to be passed as keyword arguments to
Return a dictionary to be passed as keyword arguments to
...
@@ -243,7 +244,8 @@ class MultiModalPlugin(ABC):
...
@@ -243,7 +244,8 @@ class MultiModalPlugin(ABC):
return
wrapper
return
wrapper
def
map_input
(
self
,
model_config
:
ModelConfig
,
def
map_input
(
self
,
model_config
:
ModelConfig
,
data
:
MultiModalData
[
object
])
->
MultiModalInputs
:
data
:
MultiModalData
[
object
],
mm_processor_kwargs
:
Dict
[
str
,
Any
])
->
MultiModalInputs
:
"""
"""
Transform the data into a dictionary of model inputs using the
Transform the data into a dictionary of model inputs using the
input mapper registered for that model.
input mapper registered for that model.
...
@@ -263,19 +265,26 @@ class MultiModalPlugin(ABC):
...
@@ -263,19 +265,26 @@ class MultiModalPlugin(ABC):
model_cls
,
_
=
get_model_architecture
(
model_config
)
model_cls
,
_
=
get_model_architecture
(
model_config
)
mapper
=
self
.
_input_mappers
.
get
(
model_cls
)
mapper
=
self
.
_input_mappers
.
get
(
model_cls
)
# Only get processor kwargs at mapping time if we are not using the
# input mapper; no overrides are used on the default here because they
# should be passed to the huggingface resource at initialization time.
if
mapper
is
not
None
and
mapper
!=
self
.
_default_input_mapper
:
mm_processor_kwargs
=
get_allowed_kwarg_only_overrides
(
mapper
,
overrides
=
model_config
.
mm_processor_kwargs
)
else
:
mm_processor_kwargs
=
{}
if
mapper
is
None
:
if
mapper
is
None
:
raise
KeyError
(
f
"No input mapper in
{
self
}
is registered for "
raise
KeyError
(
f
"No input mapper in
{
self
}
is registered for "
f
"model class
{
model_cls
.
__name__
}
."
)
f
"model class
{
model_cls
.
__name__
}
."
)
# In the case of the default mapper, we have to get resource
# processor through its HuggingFace autoclass; since this goes
# through **kwargs, we can't inspect it the same way, so we allow
# drop mm_processor_kwargs based on signature inspection
# if we're using the default mapper.
#
# This should be safe in general due to the sanitation, since the
# transformers resource should filter unused kwargs anyway.
uses_default_mapper
=
mapper
==
self
.
_default_input_mapper
mm_processor_kwargs
=
resolve_mm_processor_kwargs
(
model_config
.
mm_processor_kwargs
,
mm_processor_kwargs
,
callable
=
mapper
,
allow_var_kwargs
=
uses_default_mapper
,
)
return
mapper
(
InputContext
(
model_config
),
data
,
**
mm_processor_kwargs
)
return
mapper
(
InputContext
(
model_config
),
data
,
**
mm_processor_kwargs
)
@
abstractmethod
@
abstractmethod
...
...
vllm/multimodal/image.py
View file @
a3691b6b
from
functools
import
lru_cache
from
functools
import
lru_cache
from
typing
import
Any
,
Dict
,
Optional
import
torch
import
torch
from
PIL
import
Image
from
PIL
import
Image
...
@@ -23,11 +24,13 @@ class ImagePlugin(MultiModalPlugin):
...
@@ -23,11 +24,13 @@ class ImagePlugin(MultiModalPlugin):
def
get_data_key
(
self
)
->
str
:
def
get_data_key
(
self
)
->
str
:
return
"image"
return
"image"
def
_get_hf_image_processor
(
self
,
model_config
:
ModelConfig
):
def
_get_hf_image_processor
(
mm_processor_kwargs
=
({}
if
model_config
.
mm_processor_kwargs
is
None
self
,
else
model_config
.
mm_processor_kwargs
)
model_config
:
ModelConfig
,
# We don't explicitly check kwarg overrides to the HF class
mm_processor_kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
# since the automodel just takes kwargs, so we can't inspect it
):
if
mm_processor_kwargs
is
None
:
mm_processor_kwargs
=
{}
return
cached_get_image_processor
(
return
cached_get_image_processor
(
model_config
.
model
,
model_config
.
model
,
trust_remote_code
=
model_config
.
trust_remote_code
,
trust_remote_code
=
model_config
.
trust_remote_code
,
...
@@ -37,6 +40,7 @@ class ImagePlugin(MultiModalPlugin):
...
@@ -37,6 +40,7 @@ class ImagePlugin(MultiModalPlugin):
self
,
self
,
ctx
:
InputContext
,
ctx
:
InputContext
,
data
:
MultiModalData
[
object
],
data
:
MultiModalData
[
object
],
**
mm_processor_kwargs
,
)
->
MultiModalInputs
:
)
->
MultiModalInputs
:
model_config
=
ctx
.
model_config
model_config
=
ctx
.
model_config
...
@@ -46,12 +50,20 @@ class ImagePlugin(MultiModalPlugin):
...
@@ -46,12 +50,20 @@ class ImagePlugin(MultiModalPlugin):
# PIL image
# PIL image
if
isinstance
(
data
,
Image
.
Image
)
or
is_list_of
(
data
,
Image
.
Image
):
if
isinstance
(
data
,
Image
.
Image
)
or
is_list_of
(
data
,
Image
.
Image
):
image_processor
=
self
.
_get_hf_image_processor
(
model_config
)
image_processor
=
self
.
_get_hf_image_processor
(
model_config
,
mm_processor_kwargs
,
)
if
image_processor
is
None
:
if
image_processor
is
None
:
raise
RuntimeError
(
"No HuggingFace processor is available "
raise
RuntimeError
(
"No HuggingFace processor is available "
"to process the image object"
)
"to process the image object"
)
try
:
try
:
# NOTE: It may make sense to forward the mm_processor_kwargs
# here too. For now, to keep it simple, we only allow it be
# used for the initialization call though, just in case the
# signatures of the preprocessor initializer don't match
# preprocess()
batch_data
=
image_processor
\
batch_data
=
image_processor
\
.
preprocess
(
data
,
return_tensors
=
"pt"
)
\
.
preprocess
(
data
,
return_tensors
=
"pt"
)
\
.
data
.
data
...
...
vllm/multimodal/registry.py
View file @
a3691b6b
import
functools
import
functools
from
collections
import
UserDict
from
collections
import
UserDict
from
typing
import
Dict
,
Mapping
,
Optional
,
Sequence
from
typing
import
Any
,
Dict
,
Mapping
,
Optional
,
Sequence
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -96,8 +96,12 @@ class MultiModalRegistry:
...
@@ -96,8 +96,12 @@ class MultiModalRegistry:
"""
"""
return
self
.
register_input_mapper
(
"image"
,
mapper
)
return
self
.
register_input_mapper
(
"image"
,
mapper
)
def
map_input
(
self
,
model_config
:
ModelConfig
,
def
map_input
(
data
:
MultiModalDataDict
)
->
MultiModalInputs
:
self
,
model_config
:
ModelConfig
,
data
:
MultiModalDataDict
,
mm_processor_kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
)
->
MultiModalInputs
:
"""
"""
Apply an input mapper to the data passed to the model.
Apply an input mapper to the data passed to the model.
...
@@ -123,7 +127,8 @@ class MultiModalRegistry:
...
@@ -123,7 +127,8 @@ class MultiModalRegistry:
f
"`--limit-mm-per-prompt`, but found
{
num_items
}
items "
f
"`--limit-mm-per-prompt`, but found
{
num_items
}
items "
"in the same prompt."
)
"in the same prompt."
)
input_dict
=
plugin
.
map_input
(
model_config
,
data_value
)
input_dict
=
plugin
.
map_input
(
model_config
,
data_value
,
mm_processor_kwargs
)
for
input_key
,
input_tensor
in
input_dict
.
items
():
for
input_key
,
input_tensor
in
input_dict
.
items
():
if
input_key
in
merged_dict
:
if
input_key
in
merged_dict
:
raise
ValueError
(
f
"The input mappers (keys=
{
set
(
data
)
}
) "
raise
ValueError
(
f
"The input mappers (keys=
{
set
(
data
)
}
) "
...
...
vllm/multimodal/video.py
View file @
a3691b6b
from
functools
import
lru_cache
from
functools
import
lru_cache
from
typing
import
List
,
Union
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Union
import
numpy
as
np
import
numpy
as
np
...
@@ -36,11 +36,13 @@ class VideoPlugin(ImagePlugin):
...
@@ -36,11 +36,13 @@ class VideoPlugin(ImagePlugin):
def
get_data_key
(
self
)
->
str
:
def
get_data_key
(
self
)
->
str
:
return
"video"
return
"video"
def
_get_hf_video_processor
(
self
,
model_config
:
ModelConfig
):
def
_get_hf_video_processor
(
mm_processor_kwargs
=
({}
if
model_config
.
mm_processor_kwargs
is
None
self
,
else
model_config
.
mm_processor_kwargs
)
model_config
:
ModelConfig
,
# We don't explicitly check kwarg overrides to the HF class
mm_processor_kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
# since the automodel just takes kwargs, so we can't inspect it
):
if
mm_processor_kwargs
is
None
:
mm_processor_kwargs
=
{}
return
cached_get_video_processor
(
return
cached_get_video_processor
(
model_config
.
model
,
model_config
.
model
,
trust_remote_code
=
model_config
.
trust_remote_code
,
trust_remote_code
=
model_config
.
trust_remote_code
,
...
@@ -50,16 +52,24 @@ class VideoPlugin(ImagePlugin):
...
@@ -50,16 +52,24 @@ class VideoPlugin(ImagePlugin):
self
,
self
,
ctx
:
InputContext
,
ctx
:
InputContext
,
data
:
MultiModalData
[
object
],
data
:
MultiModalData
[
object
],
**
mm_processor_kwargs
,
)
->
MultiModalInputs
:
)
->
MultiModalInputs
:
model_config
=
ctx
.
model_config
model_config
=
ctx
.
model_config
# single video input as np.ndarray
# single video input as np.ndarray
if
isinstance
(
data
,
np
.
ndarray
):
if
isinstance
(
data
,
np
.
ndarray
):
video_processor
=
self
.
_get_hf_video_processor
(
model_config
)
video_processor
=
self
.
_get_hf_video_processor
(
model_config
,
mm_processor_kwargs
,
)
if
video_processor
is
None
:
if
video_processor
is
None
:
raise
RuntimeError
(
"No HuggingFace processor is available "
raise
RuntimeError
(
"No HuggingFace processor is available "
"to process the image object"
)
"to process the image object"
)
try
:
try
:
# NOTE: Similar to image; it may be a good idea to filter and
# pass mm_processor_kwargs here too, but for now we don't to
# avoid extra complexity if the initializer and preprocess
# signatures of the processor don't align
batch_data
=
video_processor
(
data
,
return_tensors
=
"pt"
).
data
batch_data
=
video_processor
(
data
,
return_tensors
=
"pt"
).
data
except
Exception
:
except
Exception
:
logger
.
error
(
"Failed to process image (%s)"
,
data
)
logger
.
error
(
"Failed to process image (%s)"
,
data
)
...
...
vllm/sequence.py
View file @
a3691b6b
...
@@ -481,6 +481,10 @@ class Sequence:
...
@@ -481,6 +481,10 @@ class Sequence:
EncoderDecoderLLMInputs
,
EncoderDecoderLLMInputs
,
inputs
).
get
(
"encoder_multi_modal_data"
))
or
{}
inputs
).
get
(
"encoder_multi_modal_data"
))
or
{}
@
property
def
mm_processor_kwargs
(
self
)
->
Dict
[
str
,
Any
]:
return
self
.
inputs
.
get
(
"mm_processor_kwargs"
)
or
{}
@
property
@
property
def
lora_int_id
(
self
)
->
int
:
def
lora_int_id
(
self
)
->
int
:
return
self
.
lora_request
.
lora_int_id
if
self
.
lora_request
else
0
return
self
.
lora_request
.
lora_int_id
if
self
.
lora_request
else
0
...
@@ -710,6 +714,14 @@ class SequenceGroup:
...
@@ -710,6 +714,14 @@ class SequenceGroup:
# We use the multi-modal data of an arbitrary sequence.
# We use the multi-modal data of an arbitrary sequence.
return
self
.
seqs
[
0
].
multi_modal_data
return
self
.
seqs
[
0
].
multi_modal_data
@
property
def
mm_processor_kwargs
(
self
)
->
Dict
[
str
,
Any
]:
# As with multi-modal data, all sequences in the group should have the
# same processor kwargs (i.e., mm_processor_kwargs are optionally
# provided per request; note that are independent of whether the model
# decoder-only or an encoder-decoder).
return
self
.
seqs
[
0
].
mm_processor_kwargs
@
property
@
property
def
lora_int_id
(
self
)
->
int
:
def
lora_int_id
(
self
)
->
int
:
return
self
.
lora_request
.
lora_int_id
if
self
.
lora_request
else
0
return
self
.
lora_request
.
lora_int_id
if
self
.
lora_request
else
0
...
@@ -949,6 +961,7 @@ class SequenceGroupMetadata(
...
@@ -949,6 +961,7 @@ class SequenceGroupMetadata(
used in prefix caching.
used in prefix caching.
state: Internal state tied to this sequence group.
state: Internal state tied to this sequence group.
multi_modal_data: Multi modal data.
multi_modal_data: Multi modal data.
mm_processor_kwargs: Multimodal input processor / mapper overrides.
encoder_seq_data: Optional sequence data for encoder prompt
encoder_seq_data: Optional sequence data for encoder prompt
(SequenceGroup.encoder_seq). Should be None
(SequenceGroup.encoder_seq). Should be None
unless you are working with an encoder/decoder
unless you are working with an encoder/decoder
...
@@ -975,6 +988,7 @@ class SequenceGroupMetadata(
...
@@ -975,6 +988,7 @@ class SequenceGroupMetadata(
# "MultiModalDataDict" types. We have to use Any due to msgspec
# "MultiModalDataDict" types. We have to use Any due to msgspec
# doesn't allow to have union of 2 different dicts.
# doesn't allow to have union of 2 different dicts.
multi_modal_data
:
Optional
[
Any
]
=
None
multi_modal_data
:
Optional
[
Any
]
=
None
mm_processor_kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
encoder_seq_data
:
Optional
[
SequenceData
]
=
None
encoder_seq_data
:
Optional
[
SequenceData
]
=
None
cross_block_table
:
Optional
[
List
[
int
]]
=
None
cross_block_table
:
Optional
[
List
[
int
]]
=
None
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
...
...
vllm/utils.py
View file @
a3691b6b
...
@@ -1277,18 +1277,87 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args,
...
@@ -1277,18 +1277,87 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args,
return
await
task
(
*
args
,
**
kwargs
)
return
await
task
(
*
args
,
**
kwargs
)
def
supports_kw
(
callable
:
Callable
[...,
object
],
kw_name
:
str
)
->
bool
:
def
supports_kw
(
callable
:
Callable
[...,
object
],
kw_name
:
str
,
requires_kw_only
:
bool
=
False
,
allow_var_kwargs
:
bool
=
True
,
)
->
bool
:
"""Check if a keyword is a valid kwarg for a callable; if requires_kw_only
disallows kwargs names that can also be positional arguments.
"""
params
=
inspect
.
signature
(
callable
).
parameters
params
=
inspect
.
signature
(
callable
).
parameters
if
kw_name
in
params
:
if
not
params
:
return
True
return
False
param_val
=
params
.
get
(
kw_name
)
# Types where the it may be valid, i.e., explicitly defined & nonvariadic
passable_kw_types
=
set
((
inspect
.
Parameter
.
POSITIONAL_ONLY
,
inspect
.
Parameter
.
POSITIONAL_OR_KEYWORD
,
inspect
.
Parameter
.
KEYWORD_ONLY
))
if
param_val
:
is_sig_param
=
param_val
.
kind
in
passable_kw_types
# We want kwargs only, but this is passable as a positional arg
if
(
requires_kw_only
and
is_sig_param
and
param_val
.
kind
!=
inspect
.
Parameter
.
KEYWORD_ONLY
):
return
False
if
((
requires_kw_only
and
param_val
.
kind
==
inspect
.
Parameter
.
KEYWORD_ONLY
)
or
(
not
requires_kw_only
and
is_sig_param
)):
return
True
# If we're okay with var-kwargs, it's supported as long as
# the kw_name isn't something like *args, **kwargs
if
allow_var_kwargs
:
# Get the last param; type is ignored here because params is a proxy
# mapping, but it wraps an ordered dict, and they appear in order.
# Ref: https://docs.python.org/3/library/inspect.html#inspect.Signature.parameters
last_param
=
params
[
next
(
reversed
(
params
))]
# type: ignore
return
(
last_param
.
kind
==
inspect
.
Parameter
.
VAR_KEYWORD
and
last_param
.
name
!=
kw_name
)
return
False
def
resolve_mm_processor_kwargs
(
init_kwargs
:
Optional
[
Dict
[
str
,
Any
]],
inference_kwargs
:
Optional
[
Dict
[
str
,
Any
]],
callable
:
Callable
[...,
object
],
allow_var_kwargs
:
bool
=
False
,
)
->
Dict
[
str
,
Any
]:
"""Applies filtering to eliminate invalid mm_processor_kwargs, i.e.,
those who are not explicit keywords to the given callable (of one is
given; otherwise no filtering is done), then merges the kwarg dicts,
giving priority to inference_kwargs if there are any collisions.
In the case that no kwarg overrides are provided, returns an empty
dict so that it can still be kwarg expanded into the callable later on.
If allow_var_kwargs=True, allows for things that can be expanded into
kwargs as long as they aren't naming collision for var_kwargs or potential
positional arguments.
"""
# Filter inference time multimodal processor kwargs provided
runtime_mm_kwargs
=
get_allowed_kwarg_only_overrides
(
callable
,
overrides
=
inference_kwargs
,
allow_var_kwargs
=
allow_var_kwargs
)
# Filter init time multimodal processor kwargs provided
init_mm_kwargs
=
get_allowed_kwarg_only_overrides
(
callable
,
overrides
=
init_kwargs
,
allow_var_kwargs
=
allow_var_kwargs
)
return
any
(
param
.
kind
==
inspect
.
Parameter
.
VAR_KEYWORD
# Merge the final processor kwargs, prioritizing inference
for
param
in
params
.
values
())
# time values over the initialization time values.
mm_processor_kwargs
=
{
**
init_mm_kwargs
,
**
runtime_mm_kwargs
}
return
mm_processor_kwargs
def
get_allowed_kwarg_only_overrides
(
def
get_allowed_kwarg_only_overrides
(
callable
:
Callable
[...,
object
],
callable
:
Callable
[...,
object
],
overrides
:
Optional
[
Dict
[
str
,
Any
]],
overrides
:
Optional
[
Dict
[
str
,
Any
]],
allow_var_kwargs
:
bool
=
False
,
)
->
Dict
[
str
,
Any
]:
)
->
Dict
[
str
,
Any
]:
"""
"""
Given a callable which has one or more keyword only params and a dict
Given a callable which has one or more keyword only params and a dict
...
@@ -1300,7 +1369,9 @@ def get_allowed_kwarg_only_overrides(
...
@@ -1300,7 +1369,9 @@ def get_allowed_kwarg_only_overrides(
Args:
Args:
callable: Callable which takes 0 or more keyword only arguments.
callable: Callable which takes 0 or more keyword only arguments.
If None is provided, all overrides names are allowed.
overrides: Potential overrides to be used when invoking the callable.
overrides: Potential overrides to be used when invoking the callable.
allow_var_kwargs: Allows overrides that are expandable for var kwargs.
Returns:
Returns:
Dictionary containing the kwargs to be leveraged which may be used
Dictionary containing the kwargs to be leveraged which may be used
...
@@ -1310,17 +1381,15 @@ def get_allowed_kwarg_only_overrides(
...
@@ -1310,17 +1381,15 @@ def get_allowed_kwarg_only_overrides(
if
not
overrides
:
if
not
overrides
:
return
{}
return
{}
allowed_override_names
=
[
# Drop any mm_processor_kwargs provided by the user that
name
for
name
,
param
in
inspect
.
signature
(
callable
).
parameters
.
items
()
# are not kwargs, unless it can fit it var_kwargs param
if
param
.
kind
==
inspect
.
Parameter
.
KEYWORD_ONLY
]
# Drop any mm_processor_kwargs provided by the user that are
# not kwarg names accepted by the provided input processor.
filtered_overrides
=
{
filtered_overrides
=
{
kwarg_name
:
val
kwarg_name
:
val
for
kwarg_name
,
val
in
overrides
.
items
()
for
kwarg_name
,
val
in
overrides
.
items
()
if
kwarg_name
in
allowed_override_names
if
supports_kw
(
callable
,
kwarg_name
,
requires_kw_only
=
True
,
allow_var_kwargs
=
allow_var_kwargs
)
}
}
# If anything is dropped, log a warning
# If anything is dropped, log a warning
...
...
vllm/worker/cpu_model_runner.py
View file @
a3691b6b
...
@@ -148,8 +148,9 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
...
@@ -148,8 +148,9 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
)
)
def
_compute_multi_modal_input
(
self
,
seq_data
:
SequenceData
,
mm_data
,
def
_compute_multi_modal_input
(
self
,
seq_data
:
SequenceData
,
mm_data
,
computed_len
:
int
):
computed_len
:
int
,
mm_kwargs
=
self
.
multi_modal_input_mapper
(
mm_data
)
mm_processor_kwargs
:
Dict
[
str
,
Any
]):
mm_kwargs
=
self
.
multi_modal_input_mapper
(
mm_data
,
mm_processor_kwargs
)
# special processing for mrope position deltas.
# special processing for mrope position deltas.
mrope_positions
=
None
mrope_positions
=
None
...
@@ -210,7 +211,8 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
...
@@ -210,7 +211,8 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
mrope_positions
=
None
mrope_positions
=
None
if
(
mm_data
:
=
seq_group_metadata
.
multi_modal_data
):
if
(
mm_data
:
=
seq_group_metadata
.
multi_modal_data
):
mm_kwargs
,
mrope_positions
=
self
.
_compute_multi_modal_input
(
mm_kwargs
,
mrope_positions
=
self
.
_compute_multi_modal_input
(
seq_data
,
mm_data
,
computed_len
)
seq_data
,
mm_data
,
computed_len
,
seq_group_metadata
.
mm_processor_kwargs
)
multi_modal_inputs_list
.
append
(
mm_kwargs
)
multi_modal_inputs_list
.
append
(
mm_kwargs
)
# Token position ids
# Token position ids
...
...
vllm/worker/model_runner.py
View file @
a3691b6b
...
@@ -640,7 +640,9 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
...
@@ -640,7 +640,9 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
if
not
mm_data
:
if
not
mm_data
:
return
return
mm_kwargs
=
self
.
multi_modal_input_mapper
(
mm_data
)
mm_kwargs
=
self
.
multi_modal_input_mapper
(
mm_data
,
mm_processor_kwargs
=
seq_group_metadata
.
mm_processor_kwargs
)
inter_data
.
multi_modal_inputs
=
mm_kwargs
inter_data
.
multi_modal_inputs
=
mm_kwargs
# special processing for mrope position deltas.
# special processing for mrope position deltas.
...
...
vllm/worker/neuron_model_runner.py
View file @
a3691b6b
...
@@ -153,7 +153,10 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
...
@@ -153,7 +153,10 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
mm_data
=
seq_group_metadata
.
multi_modal_data
mm_data
=
seq_group_metadata
.
multi_modal_data
if
mm_data
:
if
mm_data
:
# Process multi-modal data
# Process multi-modal data
mm_kwargs
=
self
.
multi_modal_input_mapper
(
mm_data
)
mm_kwargs
=
self
.
multi_modal_input_mapper
(
mm_data
,
mm_processor_kwargs
=
seq_group_metadata
.
mm_processor_kwargs
,
)
multi_modal_inputs_list
.
append
(
mm_kwargs
)
multi_modal_inputs_list
.
append
(
mm_kwargs
)
max_seq_len
=
max
(
seq_lens
)
max_seq_len
=
max
(
seq_lens
)
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment