Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b10609e6
Unverified
Commit
b10609e6
authored
Dec 15, 2024
by
Cyrus Leung
Committed by
GitHub
Dec 15, 2024
Browse files
[Misc] Clean up multi-modal processor (#11207)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
a1c02058
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
32 additions
and
38 deletions
+32
-38
examples/offline_inference_vision_language.py
examples/offline_inference_vision_language.py
+1
-4
tests/multimodal/test_processing.py
tests/multimodal/test_processing.py
+8
-9
vllm/multimodal/processing.py
vllm/multimodal/processing.py
+23
-25
No files found.
examples/offline_inference_vision_language.py
View file @
b10609e6
...
...
@@ -92,10 +92,7 @@ def run_fuyu(question: str, modality: str):
def
run_phi3v
(
question
:
str
,
modality
:
str
):
assert
modality
==
"image"
prompt
=
f
"<|user|>
\n
<|image_1|>
\n
{
question
}
<|end|>
\n
<|assistant|>
\n
"
# noqa: E501
# Note: The default setting of max_num_seqs (256) and
# max_model_len (128k) for this model may cause OOM.
# You may lower either to run this example on lower-end GPUs.
prompt
=
f
"<|user|>
\n
<|image_1|>
\n
{
question
}
<|end|>
\n
<|assistant|>
\n
"
# num_crops is an override kwarg to the multimodal image processor;
# For some models, e.g., Phi-3.5-vision-instruct, it is recommended
...
...
tests/multimodal/test_processing.py
View file @
b10609e6
...
...
@@ -2,10 +2,9 @@ from typing import cast
import
pytest
from
vllm.multimodal.processing
import
(
MultiModalDataItems
,
PromptReplacement
,
_PlaceholderInfo
,
find_text_matches
,
find_token_matches
,
iter_placeholders
,
iter_token_matches
,
from
vllm.multimodal.processing
import
(
PromptReplacement
,
_PlaceholderInfo
,
find_text_matches
,
find_token_matches
,
iter_placeholders
,
iter_token_matches
,
replace_text_matches
,
replace_token_matches
)
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
...
...
@@ -314,8 +313,8 @@ def test_find_replace_text(
result
=
replace_text_matches
(
prompt
,
matches
,
MultiModalDataItems
({
key
:
[
None
]
*
mm_count
for
key
in
repl_by_key
}
)
,
{
key
:
mm_count
for
key
in
repl_by_key
},
)
# Only displayed on error
...
...
@@ -380,8 +379,8 @@ def test_find_replace_tokens(
result
=
replace_token_matches
(
prompt
,
matches
,
MultiModalDataItems
({
key
:
[
None
]
*
mm_count
for
key
in
repl_by_key
}
)
,
{
key
:
mm_count
for
key
in
repl_by_key
},
)
# Only displayed on error
...
...
@@ -476,7 +475,7 @@ def test_iter_placeholders(
prompt_repls
,
prompt
,
# Effectively match all occurrences in the prompt
MultiModalDataItems
({
key
:
[
None
]
*
3
for
key
in
repl_by_key
}
)
,
{
key
:
3
for
key
in
repl_by_key
},
))
# Only displayed on error
...
...
vllm/multimodal/processing.py
View file @
b10609e6
...
...
@@ -403,18 +403,17 @@ def _resolve_matches(
def
_replace_matches
(
prompt
:
_S
,
matches
:
Sequence
[
_PromptReplacementMatch
],
mm_item
s
:
MultiModalDataItems
,
mm_item
_counts
:
Mapping
[
str
,
int
]
,
)
->
list
[
_S
]:
out_seqs
=
list
[
_S
]()
prev_end_idx
=
0
next_idx_by_modality
=
{
modality
:
0
for
modality
in
mm_items
}
next_idx_by_modality
=
{
modality
:
0
for
modality
in
mm_item
_count
s
}
for
match
in
_resolve_matches
(
prompt
,
matches
):
modality
=
match
.
modality
modal_items
=
mm_items
[
modality
]
item_idx
=
next_idx_by_modality
[
modality
]
if
item_idx
>=
len
(
modal
_
it
ems
)
:
if
item_idx
>=
mm_item_counts
[
modalit
y
]
:
continue
start_idx
=
match
.
start_idx
...
...
@@ -441,13 +440,13 @@ def _replace_matches(
def
replace_token_matches
(
prompt
:
list
[
int
],
matches
:
Sequence
[
_PromptReplacementTokenMatch
],
mm_item
s
:
MultiModalDataItems
,
mm_item
_counts
:
Mapping
[
str
,
int
]
,
)
->
list
[
int
]:
"""Apply :code:`prompt_repls` to :code:`prompt`."""
if
not
matches
:
return
prompt
token_id_seqs
=
_replace_matches
(
prompt
,
matches
,
mm_items
)
token_id_seqs
=
_replace_matches
(
prompt
,
matches
,
mm_item
_count
s
)
return
flatten_2d_lists
(
token_id_seqs
)
...
...
@@ -455,13 +454,13 @@ def replace_token_matches(
def
replace_text_matches
(
prompt
:
str
,
matches
:
Sequence
[
_PromptReplacementTextMatch
],
mm_item
s
:
MultiModalDataItems
,
mm_item
_counts
:
Mapping
[
str
,
int
]
,
)
->
str
:
"""Apply :code:`prompt_repls` to :code:`prompt`."""
if
not
matches
:
return
prompt
texts
=
_replace_matches
(
prompt
,
matches
,
mm_items
)
texts
=
_replace_matches
(
prompt
,
matches
,
mm_item
_count
s
)
return
""
.
join
(
texts
)
...
...
@@ -470,9 +469,9 @@ def _iter_modality_placeholders(
prompt
:
list
[
int
],
modality
:
str
,
modality_repls
:
Sequence
[
_BoundPromptReplacement
],
modal_item
s
:
list
[
Any
]
,
modal_item
_count
:
int
,
)
->
Iterable
[
_PlaceholderInfo
]:
if
len
(
modal_item
s
)
==
0
:
if
modal_item
_count
==
0
:
return
prompt_len
=
len
(
prompt
)
...
...
@@ -499,7 +498,7 @@ def _iter_modality_placeholders(
)
item_index
+=
1
if
item_index
>=
len
(
modal_item
s
)
:
if
item_index
>=
modal_item
_count
:
return
# Exclude overlapping matches
...
...
@@ -514,7 +513,7 @@ def _iter_modality_placeholders(
def
iter_placeholders
(
prompt_repls
:
Sequence
[
_BoundPromptReplacement
],
prompt
:
list
[
int
],
mm_item
s
:
MultiModalDataItems
,
mm_item
_counts
:
Mapping
[
str
,
int
]
,
)
->
Iterable
[
_PlaceholderInfo
]:
"""
Yield each set of placeholder tokens found in :code:`prompt`.
...
...
@@ -523,13 +522,13 @@ def iter_placeholders(
"""
repls_by_modality
=
dict
(
full_groupby_modality
(
prompt_repls
))
for
modality
,
modal_item
s
in
mm_items
.
items
():
for
modality
,
modal_item
_count
in
mm_item
_count
s
.
items
():
if
modality
in
repls_by_modality
:
yield
from
_iter_modality_placeholders
(
prompt
,
modality
,
repls_by_modality
[
modality
],
modal_item
s
,
modal_item
_count
,
)
...
...
@@ -590,10 +589,10 @@ class BaseMultiModalProcessor(ABC):
self
,
all_prompt_repls
:
Sequence
[
_BoundPromptReplacement
],
new_token_ids
:
list
[
int
],
mm_item
s
:
MultiModalDataItems
,
mm_item
_counts
:
Mapping
[
str
,
int
]
,
)
->
list
[
_PlaceholderInfo
]:
return
list
(
iter_placeholders
(
all_prompt_repls
,
new_token_ids
,
mm_items
))
iter_placeholders
(
all_prompt_repls
,
new_token_ids
,
mm_item
_count
s
))
def
_apply_hf_processor
(
self
,
...
...
@@ -655,10 +654,9 @@ class BaseMultiModalProcessor(ABC):
def
_apply_prompt_replacements
(
self
,
mm_items
:
MultiModalDataItems
,
hf_inputs
:
BatchFeature
,
token_ids
:
list
[
int
],
prompt_repls
:
Sequence
[
_BoundPromptReplacement
],
mm_item_counts
:
Mapping
[
str
,
int
],
)
->
tuple
[
list
[
int
],
str
,
list
[
_PlaceholderInfo
]]:
tokenizer
=
self
.
_get_tokenizer
()
...
...
@@ -675,13 +673,13 @@ class BaseMultiModalProcessor(ABC):
# of the search text in the prompt, we instead perform string
# replacement on the decoded token IDs, then encode them back.
if
all
(
len
(
matches
)
>=
len
(
mm_items
[
modality
]
)
len
(
matches
)
>=
mm_item
_count
s
[
modality
]
for
modality
,
matches
in
full_groupby_modality
(
token_matches
)
):
# yapf: disable
token_ids
=
replace_token_matches
(
token_ids
,
token_matches
,
mm_items
,
mm_item
_count
s
,
)
text
=
_decode
(
tokenizer
,
token_ids
)
...
...
@@ -693,14 +691,14 @@ class BaseMultiModalProcessor(ABC):
text
=
replace_text_matches
(
text
,
text_matches
,
mm_items
,
mm_item
_count
s
,
)
token_ids
=
_encode
(
tokenizer
,
text
)
matched_repls
=
[
match
.
prompt_repl
for
match
in
text_matches
]
placeholders
=
self
.
_find_placeholders
(
matched_repls
,
token_ids
,
mm_items
)
mm_item
_count
s
)
return
token_ids
,
text
,
placeholders
...
...
@@ -737,8 +735,9 @@ class BaseMultiModalProcessor(ABC):
# If HF processor already inserts placeholder tokens,
# there is no need for us to insert them
mm_item_counts
=
{
m
:
len
(
items
)
for
m
,
items
in
mm_items
.
items
()}
all_placeholders
=
self
.
_find_placeholders
(
all_prompt_repls
,
prompt_ids
,
mm_items
)
prompt_ids
,
mm_item
_count
s
)
if
all_placeholders
:
prompt_text
=
_decode
(
tokenizer
,
prompt_ids
)
...
...
@@ -748,10 +747,9 @@ class BaseMultiModalProcessor(ABC):
prompt_text
,
all_placeholders
,
)
=
self
.
_apply_prompt_replacements
(
mm_items
,
hf_inputs
,
prompt_ids
,
all_prompt_repls
,
mm_item_counts
,
)
mm_placeholders
=
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment