Unverified Commit aa46953a authored by Roger Wang's avatar Roger Wang Committed by GitHub
Browse files

[Misc][VLM][Doc] Consolidate offline examples for vision language models (#6858)


Co-authored-by: default avatarCyrus Leung <tlleungac@connect.ust.hk>
parent 593e79e7
import requests
from PIL import Image
from vllm import LLM, SamplingParams
def run_fuyu():
llm = LLM(model="adept/fuyu-8b", max_model_len=4096)
# single-image prompt
prompt = "What is the highest life expectancy at of male?\n"
url = "https://huggingface.co/adept/fuyu-8b/resolve/main/chart.png"
image = Image.open(requests.get(url, stream=True).raw)
sampling_params = SamplingParams(temperature=0, max_tokens=64)
outputs = llm.generate(
{
"prompt": prompt,
"multi_modal_data": {
"image": image
},
},
sampling_params=sampling_params)
for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)
if __name__ == "__main__":
run_fuyu()
from vllm import LLM
from vllm.assets.image import ImageAsset
def run_llava():
llm = LLM(model="llava-hf/llava-1.5-7b-hf")
prompt = "USER: <image>\nWhat is the content of this image?\nASSISTANT:"
image = ImageAsset("stop_sign").pil_image
outputs = llm.generate({
"prompt": prompt,
"multi_modal_data": {
"image": image
},
})
for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)
if __name__ == "__main__":
run_llava()
from io import BytesIO
import requests
from PIL import Image
from vllm import LLM, SamplingParams
def run_llava_next():
llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf", max_model_len=4096)
prompt = "[INST] <image>\nWhat is shown in this image? [/INST]"
url = "https://h2o-release.s3.amazonaws.com/h2ogpt/bigben.jpg"
image = Image.open(BytesIO(requests.get(url).content))
sampling_params = SamplingParams(temperature=0.8,
top_p=0.95,
max_tokens=100)
outputs = llm.generate(
{
"prompt": prompt,
"multi_modal_data": {
"image": image
}
},
sampling_params=sampling_params)
generated_text = ""
for o in outputs:
generated_text += o.outputs[0].text
print(f"LLM output:{generated_text}")
if __name__ == "__main__":
run_llava_next()
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset
# 2.0
# The official repo doesn't work yet, so we need to use a fork for now
# For more details, please see: See: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630
# MODEL_NAME = "HwwwH/MiniCPM-V-2"
# 2.5
MODEL_NAME = "openbmb/MiniCPM-Llama3-V-2_5"
image = ImageAsset("stop_sign").pil_image.convert("RGB")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
llm = LLM(model=MODEL_NAME,
gpu_memory_utilization=1,
trust_remote_code=True,
max_model_len=4096)
messages = [{
'role':
'user',
'content':
'(<image>./</image>)\n' + "What's the content of the image?"
}]
prompt = tokenizer.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)
# 2.0
# stop_token_ids = [tokenizer.eos_id]
# 2.5
stop_token_ids = [tokenizer.eos_id, tokenizer.eot_id]
sampling_params = SamplingParams(
stop_token_ids=stop_token_ids,
# temperature=0.7,
# top_p=0.8,
# top_k=100,
# seed=3472,
max_tokens=1024,
# min_tokens=150,
temperature=0,
use_beam_search=True,
# length_penalty=1.2,
best_of=3)
outputs = llm.generate({
"prompt": prompt,
"multi_modal_data": {
"image": image
}
},
sampling_params=sampling_params)
print(outputs[0].outputs[0].text)
"""
This example shows how to use vLLM for running offline inference
with the correct prompt format on vision language models.
For most models, the prompt format should follow corresponding examples
on HuggingFace model repository.
"""
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset
from vllm.utils import FlexibleArgumentParser
# Input image and question
image = ImageAsset("cherry_blossom").pil_image.convert("RGB")
question = "What is the content of this image?"
# LLaVA-1.5
def run_llava(question):
prompt = f"USER: <image>\n{question}\nASSISTANT:"
llm = LLM(model="llava-hf/llava-1.5-7b-hf")
return llm, prompt
# LLaVA-1.6/LLaVA-NeXT
def run_llava_next(question):
prompt = f"[INST] <image>\n{question} [/INST]"
llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf")
return llm, prompt
# Fuyu
def run_fuyu(question):
prompt = f"{question}\n"
llm = LLM(model="adept/fuyu-8b")
return llm, prompt
# Phi-3-Vision
def run_phi3v(question):
prompt = f"<|user|>\n<|image_1|>\n{question}<|end|>\n<|assistant|>\n" # noqa: E501
# Note: The default setting of max_num_seqs (256) and
# max_model_len (128k) for this model may cause OOM.
# You may lower either to run this example on lower-end GPUs.
# In this example, we override max_num_seqs to 5 while
# keeping the original context length of 128k.
llm = LLM(
model="microsoft/Phi-3-vision-128k-instruct",
trust_remote_code=True,
max_num_seqs=5,
)
return llm, prompt
# PaliGemma
def run_paligemma(question):
prompt = question
llm = LLM(model="google/paligemma-3b-mix-224")
return llm, prompt
# Chameleon
def run_chameleon(question):
prompt = f"{question}<image>"
llm = LLM(model="facebook/chameleon-7b")
return llm, prompt
# MiniCPM-V
def run_minicpmv(question):
# 2.0
# The official repo doesn't work yet, so we need to use a fork for now
# For more details, please see: See: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630 # noqa
# model_name = "HwwwH/MiniCPM-V-2"
# 2.5
model_name = "openbmb/MiniCPM-Llama3-V-2_5"
tokenizer = AutoTokenizer.from_pretrained(model_name,
trust_remote_code=True)
llm = LLM(
model=model_name,
trust_remote_code=True,
)
messages = [{
'role': 'user',
'content': f'(<image>./</image>)\n{question}'
}]
prompt = tokenizer.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)
return llm, prompt
model_example_map = {
"llava": run_llava,
"llava-next": run_llava_next,
"fuyu": run_fuyu,
"phi3_v": run_phi3v,
"paligemma": run_paligemma,
"chameleon": run_chameleon,
"minicpmv": run_minicpmv,
}
def main(args):
model = args.model_type
if model not in model_example_map:
raise ValueError(f"Model type {model} is not supported.")
llm, prompt = model_example_map[model](question)
# We set temperature to 0.2 so that outputs can be different
# even when all prompts are identical when running batch inference.
sampling_params = SamplingParams(temperature=0.2, max_tokens=64)
assert args.num_prompts > 0
if args.num_prompts == 1:
# Single inference
inputs = {
"prompt": prompt,
"multi_modal_data": {
"image": image
},
}
else:
# Batch inference
inputs = [{
"prompt": prompt,
"multi_modal_data": {
"image": image
},
} for _ in range(args.num_prompts)]
outputs = llm.generate(inputs, sampling_params=sampling_params)
for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)
if __name__ == "__main__":
parser = FlexibleArgumentParser(
description='Demo on using vLLM for offline inference with '
'vision language models')
args = parser.parse_args()
parser.add_argument('--model-type',
'-m',
type=str,
default="llava",
choices=model_example_map.keys(),
help='Huggingface "model_type".')
parser.add_argument('--num-prompts',
type=int,
default=1,
help='Number of prompts to run.')
args = parser.parse_args()
main(args)
from vllm import LLM
from vllm.assets.image import ImageAsset
def run_paligemma():
llm = LLM(model="google/paligemma-3b-mix-224")
prompt = "caption es"
image = ImageAsset("stop_sign").pil_image
outputs = llm.generate({
"prompt": prompt,
"multi_modal_data": {
"image": image
},
})
for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)
if __name__ == "__main__":
run_paligemma()
from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset
def run_phi3v():
model_path = "microsoft/Phi-3-vision-128k-instruct"
# Note: The default setting of max_num_seqs (256) and
# max_model_len (128k) for this model may cause OOM.
# You may lower either to run this example on lower-end GPUs.
# In this example, we override max_num_seqs to 5 while
# keeping the original context length of 128k.
llm = LLM(
model=model_path,
trust_remote_code=True,
max_num_seqs=5,
)
image = ImageAsset("cherry_blossom").pil_image
# single-image prompt
prompt = "<|user|>\n<|image_1|>\nWhat is the season?<|end|>\n<|assistant|>\n" # noqa: E501
sampling_params = SamplingParams(temperature=0, max_tokens=64)
outputs = llm.generate(
{
"prompt": prompt,
"multi_modal_data": {
"image": image
},
},
sampling_params=sampling_params)
for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)
if __name__ == "__main__":
run_phi3v()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment