Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8ca5051b
Unverified
Commit
8ca5051b
authored
Sep 22, 2024
by
Alex Brooks
Committed by
GitHub
Sep 22, 2024
Browse files
[Misc] Use NamedTuple in Multi-image example (#8705)
Signed-off-by:
Alex-Brooks
<
Alex.Brooks@ibm.com
>
parent
06ed2815
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
52 additions
and
22 deletions
+52
-22
examples/offline_inference_vision_language_multi_image.py
examples/offline_inference_vision_language_multi_image.py
+52
-22
No files found.
examples/offline_inference_vision_language_multi_image.py
View file @
8ca5051b
...
...
@@ -4,8 +4,9 @@ multi-image input on vision language models, using the chat template defined
by the model.
"""
from
argparse
import
Namespace
from
typing
import
List
from
typing
import
List
,
NamedTuple
,
Optional
from
PIL.Image
import
Image
from
transformers
import
AutoProcessor
,
AutoTokenizer
from
vllm
import
LLM
,
SamplingParams
...
...
@@ -19,7 +20,15 @@ IMAGE_URLS = [
]
def
load_qwenvl_chat
(
question
:
str
,
image_urls
:
List
[
str
]):
class ModelRequestData(NamedTuple):
    """Bundle of everything a model loader hands back to the runners.

    Each ``load_*`` function in this example returns one of these so that
    ``run_generate`` and ``run_chat`` can access the pieces by name rather
    than by position in a bare 5-tuple.
    """

    # Engine instance; the runners call ``.generate(...)`` or ``.chat(...)``
    # on it.
    llm: LLM
    # Fully formatted prompt text, sent as the "prompt" entry of the
    # generate request.
    prompt: str
    # Token ids forwarded to ``SamplingParams(stop_token_ids=...)``.
    # The loaders build these with ``tokenizer.convert_tokens_to_ids``, so
    # they are integer ids (not token strings); ``None`` when the loader
    # defines no extra stop tokens.
    stop_token_ids: Optional[List[int]]
    # Images sent as the "image" entry of ``multi_modal_data``.
    image_data: List[Image]
    # Chat template forwarded to ``llm.chat(chat_template=...)``; loaders
    # that do not supply their own template pass ``None``.
    chat_template: Optional[str]
def
load_qwenvl_chat
(
question
:
str
,
image_urls
:
List
[
str
])
->
ModelRequestData
:
model_name
=
"Qwen/Qwen-VL-Chat"
llm
=
LLM
(
model
=
model_name
,
...
...
@@ -48,10 +57,16 @@ def load_qwenvl_chat(question: str, image_urls: List[str]):
stop_tokens
=
[
"<|endoftext|>"
,
"<|im_start|>"
,
"<|im_end|>"
]
stop_token_ids
=
[
tokenizer
.
convert_tokens_to_ids
(
i
)
for
i
in
stop_tokens
]
return
llm
,
prompt
,
stop_token_ids
,
None
,
chat_template
return
ModelRequestData
(
llm
=
llm
,
prompt
=
prompt
,
stop_token_ids
=
stop_token_ids
,
image_data
=
[
fetch_image
(
url
)
for
url
in
image_urls
],
chat_template
=
chat_template
,
)
def
load_phi3v
(
question
:
str
,
image_urls
:
List
[
str
]):
def
load_phi3v
(
question
:
str
,
image_urls
:
List
[
str
])
->
ModelRequestData
:
llm
=
LLM
(
model
=
"microsoft/Phi-3.5-vision-instruct"
,
trust_remote_code
=
True
,
...
...
@@ -62,10 +77,17 @@ def load_phi3v(question: str, image_urls: List[str]):
for
i
,
_
in
enumerate
(
image_urls
,
start
=
1
))
prompt
=
f
"<|user|>
\n
{
placeholders
}
\n
{
question
}
<|end|>
\n
<|assistant|>
\n
"
stop_token_ids
=
None
return
llm
,
prompt
,
stop_token_ids
,
None
,
None
return
ModelRequestData
(
llm
=
llm
,
prompt
=
prompt
,
stop_token_ids
=
stop_token_ids
,
image_data
=
[
fetch_image
(
url
)
for
url
in
image_urls
],
chat_template
=
None
,
)
def
load_internvl
(
question
:
str
,
image_urls
:
List
[
str
]):
def
load_internvl
(
question
:
str
,
image_urls
:
List
[
str
])
->
ModelRequestData
:
model_name
=
"OpenGVLab/InternVL2-2B"
llm
=
LLM
(
...
...
@@ -93,10 +115,16 @@ def load_internvl(question: str, image_urls: List[str]):
stop_tokens
=
[
"<|endoftext|>"
,
"<|im_start|>"
,
"<|im_end|>"
,
"<|end|>"
]
stop_token_ids
=
[
tokenizer
.
convert_tokens_to_ids
(
i
)
for
i
in
stop_tokens
]
return
llm
,
prompt
,
stop_token_ids
,
None
,
None
return
ModelRequestData
(
llm
=
llm
,
prompt
=
prompt
,
stop_token_ids
=
stop_token_ids
,
image_data
=
[
fetch_image
(
url
)
for
url
in
image_urls
],
chat_template
=
None
,
)
def
load_qwen2_vl
(
question
,
image_urls
:
List
[
str
]):
def
load_qwen2_vl
(
question
,
image_urls
:
List
[
str
])
->
ModelRequestData
:
try
:
from
qwen_vl_utils
import
process_vision_info
except
ModuleNotFoundError
:
...
...
@@ -143,7 +171,13 @@ def load_qwen2_vl(question, image_urls: List[str]):
else
:
image_data
,
_
=
process_vision_info
(
messages
)
return
llm
,
prompt
,
stop_token_ids
,
image_data
,
None
return
ModelRequestData
(
llm
=
llm
,
prompt
=
prompt
,
stop_token_ids
=
stop_token_ids
,
image_data
=
image_data
,
chat_template
=
None
,
)
model_example_map
=
{
...
...
@@ -155,20 +189,17 @@ model_example_map = {
def
run_generate
(
model
,
question
:
str
,
image_urls
:
List
[
str
]):
llm
,
prompt
,
stop_token_ids
,
image_data
,
_
=
model_example_map
[
model
](
question
,
image_urls
)
if
image_data
is
None
:
image_data
=
[
fetch_image
(
url
)
for
url
in
image_urls
]
req_data
=
model_example_map
[
model
](
question
,
image_urls
)
sampling_params
=
SamplingParams
(
temperature
=
0.0
,
max_tokens
=
128
,
stop_token_ids
=
stop_token_ids
)
stop_token_ids
=
req_data
.
stop_token_ids
)
outputs
=
llm
.
generate
(
outputs
=
req_data
.
llm
.
generate
(
{
"prompt"
:
prompt
,
"prompt"
:
req_data
.
prompt
,
"multi_modal_data"
:
{
"image"
:
image_data
"image"
:
req_data
.
image_data
},
},
sampling_params
=
sampling_params
)
...
...
@@ -179,13 +210,12 @@ def run_generate(model, question: str, image_urls: List[str]):
def
run_chat
(
model
:
str
,
question
:
str
,
image_urls
:
List
[
str
]):
llm
,
_
,
stop_token_ids
,
_
,
chat_template
=
model_example_map
[
model
](
question
,
image_urls
)
req_data
=
model_example_map
[
model
](
question
,
image_urls
)
sampling_params
=
SamplingParams
(
temperature
=
0.0
,
max_tokens
=
128
,
stop_token_ids
=
stop_token_ids
)
outputs
=
llm
.
chat
(
stop_token_ids
=
req_data
.
stop_token_ids
)
outputs
=
req_data
.
llm
.
chat
(
[{
"role"
:
"user"
,
...
...
@@ -203,7 +233,7 @@ def run_chat(model: str, question: str, image_urls: List[str]):
],
}],
sampling_params
=
sampling_params
,
chat_template
=
chat_template
,
chat_template
=
req_data
.
chat_template
,
)
for
o
in
outputs
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment