Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8ca5051b
Unverified
Commit
8ca5051b
authored
Sep 22, 2024
by
Alex Brooks
Committed by
GitHub
Sep 22, 2024
Browse files
[Misc] Use NamedTuple in Multi-image example (#8705)
Signed-off-by:
Alex-Brooks
<
Alex.Brooks@ibm.com
>
parent
06ed2815
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
52 additions
and
22 deletions
+52
-22
examples/offline_inference_vision_language_multi_image.py
examples/offline_inference_vision_language_multi_image.py
+52
-22
No files found.
examples/offline_inference_vision_language_multi_image.py
View file @
8ca5051b
...
@@ -4,8 +4,9 @@ multi-image input on vision language models, using the chat template defined
...
@@ -4,8 +4,9 @@ multi-image input on vision language models, using the chat template defined
by the model.
by the model.
"""
"""
from
argparse
import
Namespace
from
argparse
import
Namespace
from
typing
import
List
from
typing
import
List
,
NamedTuple
,
Optional
from
PIL.Image
import
Image
from
transformers
import
AutoProcessor
,
AutoTokenizer
from
transformers
import
AutoProcessor
,
AutoTokenizer
from
vllm
import
LLM
,
SamplingParams
from
vllm
import
LLM
,
SamplingParams
...
@@ -19,7 +20,15 @@ IMAGE_URLS = [
...
@@ -19,7 +20,15 @@ IMAGE_URLS = [
]
]
def
load_qwenvl_chat
(
question
:
str
,
image_urls
:
List
[
str
]):
class ModelRequestData(NamedTuple):
    """Bundle of everything a model loader hands back to the example runners.

    Each ``load_*`` function in this file returns one of these, so that
    ``run_generate`` and ``run_chat`` can read fields by name instead of
    unpacking a positional 5-tuple.
    """

    # Engine instance; the runners call ``.generate(...)`` or ``.chat(...)`` on it.
    llm: LLM
    # Fully formatted prompt text sent under the "prompt" key in ``run_generate``.
    prompt: str
    # Token *ids* (integers, as produced by ``tokenizer.convert_tokens_to_ids``)
    # forwarded to ``SamplingParams(stop_token_ids=...)``; ``None`` when the
    # loader defines no stop tokens.
    stop_token_ids: Optional[List[int]]
    # Images supplied under ``multi_modal_data["image"]`` in ``run_generate``.
    image_data: List[Image]
    # Chat template override passed to ``llm.chat``; ``None`` when the loader
    # does not provide one.
    chat_template: Optional[str]
def
load_qwenvl_chat
(
question
:
str
,
image_urls
:
List
[
str
])
->
ModelRequestData
:
model_name
=
"Qwen/Qwen-VL-Chat"
model_name
=
"Qwen/Qwen-VL-Chat"
llm
=
LLM
(
llm
=
LLM
(
model
=
model_name
,
model
=
model_name
,
...
@@ -48,10 +57,16 @@ def load_qwenvl_chat(question: str, image_urls: List[str]):
...
@@ -48,10 +57,16 @@ def load_qwenvl_chat(question: str, image_urls: List[str]):
stop_tokens
=
[
"<|endoftext|>"
,
"<|im_start|>"
,
"<|im_end|>"
]
stop_tokens
=
[
"<|endoftext|>"
,
"<|im_start|>"
,
"<|im_end|>"
]
stop_token_ids
=
[
tokenizer
.
convert_tokens_to_ids
(
i
)
for
i
in
stop_tokens
]
stop_token_ids
=
[
tokenizer
.
convert_tokens_to_ids
(
i
)
for
i
in
stop_tokens
]
return
llm
,
prompt
,
stop_token_ids
,
None
,
chat_template
return
ModelRequestData
(
llm
=
llm
,
prompt
=
prompt
,
stop_token_ids
=
stop_token_ids
,
image_data
=
[
fetch_image
(
url
)
for
url
in
image_urls
],
chat_template
=
chat_template
,
)
def
load_phi3v
(
question
:
str
,
image_urls
:
List
[
str
]):
def
load_phi3v
(
question
:
str
,
image_urls
:
List
[
str
])
->
ModelRequestData
:
llm
=
LLM
(
llm
=
LLM
(
model
=
"microsoft/Phi-3.5-vision-instruct"
,
model
=
"microsoft/Phi-3.5-vision-instruct"
,
trust_remote_code
=
True
,
trust_remote_code
=
True
,
...
@@ -62,10 +77,17 @@ def load_phi3v(question: str, image_urls: List[str]):
...
@@ -62,10 +77,17 @@ def load_phi3v(question: str, image_urls: List[str]):
for
i
,
_
in
enumerate
(
image_urls
,
start
=
1
))
for
i
,
_
in
enumerate
(
image_urls
,
start
=
1
))
prompt
=
f
"<|user|>
\n
{
placeholders
}
\n
{
question
}
<|end|>
\n
<|assistant|>
\n
"
prompt
=
f
"<|user|>
\n
{
placeholders
}
\n
{
question
}
<|end|>
\n
<|assistant|>
\n
"
stop_token_ids
=
None
stop_token_ids
=
None
return
llm
,
prompt
,
stop_token_ids
,
None
,
None
return
ModelRequestData
(
llm
=
llm
,
prompt
=
prompt
,
stop_token_ids
=
stop_token_ids
,
image_data
=
[
fetch_image
(
url
)
for
url
in
image_urls
],
chat_template
=
None
,
)
def
load_internvl
(
question
:
str
,
image_urls
:
List
[
str
]):
def
load_internvl
(
question
:
str
,
image_urls
:
List
[
str
])
->
ModelRequestData
:
model_name
=
"OpenGVLab/InternVL2-2B"
model_name
=
"OpenGVLab/InternVL2-2B"
llm
=
LLM
(
llm
=
LLM
(
...
@@ -93,10 +115,16 @@ def load_internvl(question: str, image_urls: List[str]):
...
@@ -93,10 +115,16 @@ def load_internvl(question: str, image_urls: List[str]):
stop_tokens
=
[
"<|endoftext|>"
,
"<|im_start|>"
,
"<|im_end|>"
,
"<|end|>"
]
stop_tokens
=
[
"<|endoftext|>"
,
"<|im_start|>"
,
"<|im_end|>"
,
"<|end|>"
]
stop_token_ids
=
[
tokenizer
.
convert_tokens_to_ids
(
i
)
for
i
in
stop_tokens
]
stop_token_ids
=
[
tokenizer
.
convert_tokens_to_ids
(
i
)
for
i
in
stop_tokens
]
return
llm
,
prompt
,
stop_token_ids
,
None
,
None
return
ModelRequestData
(
llm
=
llm
,
prompt
=
prompt
,
stop_token_ids
=
stop_token_ids
,
image_data
=
[
fetch_image
(
url
)
for
url
in
image_urls
],
chat_template
=
None
,
)
def
load_qwen2_vl
(
question
,
image_urls
:
List
[
str
]):
def
load_qwen2_vl
(
question
,
image_urls
:
List
[
str
])
->
ModelRequestData
:
try
:
try
:
from
qwen_vl_utils
import
process_vision_info
from
qwen_vl_utils
import
process_vision_info
except
ModuleNotFoundError
:
except
ModuleNotFoundError
:
...
@@ -143,7 +171,13 @@ def load_qwen2_vl(question, image_urls: List[str]):
...
@@ -143,7 +171,13 @@ def load_qwen2_vl(question, image_urls: List[str]):
else
:
else
:
image_data
,
_
=
process_vision_info
(
messages
)
image_data
,
_
=
process_vision_info
(
messages
)
return
llm
,
prompt
,
stop_token_ids
,
image_data
,
None
return
ModelRequestData
(
llm
=
llm
,
prompt
=
prompt
,
stop_token_ids
=
stop_token_ids
,
image_data
=
image_data
,
chat_template
=
None
,
)
model_example_map
=
{
model_example_map
=
{
...
@@ -155,20 +189,17 @@ model_example_map = {
...
@@ -155,20 +189,17 @@ model_example_map = {
def
run_generate
(
model
,
question
:
str
,
image_urls
:
List
[
str
]):
def
run_generate
(
model
,
question
:
str
,
image_urls
:
List
[
str
]):
llm
,
prompt
,
stop_token_ids
,
image_data
,
_
=
model_example_map
[
model
](
req_data
=
model_example_map
[
model
](
question
,
image_urls
)
question
,
image_urls
)
if
image_data
is
None
:
image_data
=
[
fetch_image
(
url
)
for
url
in
image_urls
]
sampling_params
=
SamplingParams
(
temperature
=
0.0
,
sampling_params
=
SamplingParams
(
temperature
=
0.0
,
max_tokens
=
128
,
max_tokens
=
128
,
stop_token_ids
=
stop_token_ids
)
stop_token_ids
=
req_data
.
stop_token_ids
)
outputs
=
llm
.
generate
(
outputs
=
req_data
.
llm
.
generate
(
{
{
"prompt"
:
prompt
,
"prompt"
:
req_data
.
prompt
,
"multi_modal_data"
:
{
"multi_modal_data"
:
{
"image"
:
image_data
"image"
:
req_data
.
image_data
},
},
},
},
sampling_params
=
sampling_params
)
sampling_params
=
sampling_params
)
...
@@ -179,13 +210,12 @@ def run_generate(model, question: str, image_urls: List[str]):
...
@@ -179,13 +210,12 @@ def run_generate(model, question: str, image_urls: List[str]):
def
run_chat
(
model
:
str
,
question
:
str
,
image_urls
:
List
[
str
]):
def
run_chat
(
model
:
str
,
question
:
str
,
image_urls
:
List
[
str
]):
llm
,
_
,
stop_token_ids
,
_
,
chat_template
=
model_example_map
[
model
](
req_data
=
model_example_map
[
model
](
question
,
image_urls
)
question
,
image_urls
)
sampling_params
=
SamplingParams
(
temperature
=
0.0
,
sampling_params
=
SamplingParams
(
temperature
=
0.0
,
max_tokens
=
128
,
max_tokens
=
128
,
stop_token_ids
=
stop_token_ids
)
stop_token_ids
=
req_data
.
stop_token_ids
)
outputs
=
llm
.
chat
(
outputs
=
req_data
.
llm
.
chat
(
[{
[{
"role"
:
"role"
:
"user"
,
"user"
,
...
@@ -203,7 +233,7 @@ def run_chat(model: str, question: str, image_urls: List[str]):
...
@@ -203,7 +233,7 @@ def run_chat(model: str, question: str, image_urls: List[str]):
],
],
}],
}],
sampling_params
=
sampling_params
,
sampling_params
=
sampling_params
,
chat_template
=
chat_template
,
chat_template
=
req_data
.
chat_template
,
)
)
for
o
in
outputs
:
for
o
in
outputs
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment