Unverified commit e7beff8a authored by XinyuanTong, committed by GitHub

fix: examples for token_in_token_out_vlm (#5193)

parent 4d2e3051
 import argparse
 import dataclasses
-from io import BytesIO
 from typing import Tuple
-import requests
-from PIL import Image
 from transformers import AutoProcessor
 from sglang import Engine
...
@@ -19,20 +16,22 @@ def get_input_ids(
 ) -> Tuple[list[int], list]:
     chat_template = get_chat_template_by_model_path(model_config.model_path)
     text = f"{chat_template.image_token}What is in this picture?"
-    images = [Image.open(BytesIO(requests.get(DEFAULT_IMAGE_URL).content))]
     image_data = [DEFAULT_IMAGE_URL]
 
     processor = AutoProcessor.from_pretrained(
         model_config.model_path, trust_remote_code=server_args.trust_remote_code
     )
-    inputs = processor(
-        text=[text],
-        images=images,
-        return_tensors="pt",
-    )
-    return inputs.input_ids[0].tolist(), image_data
+    input_ids = (
+        processor.tokenizer(
+            text=[text],
+            return_tensors="pt",
+        )
+        .input_ids[0]
+        .tolist()
+    )
+    return input_ids, image_data
 
 
 def token_in_out_example(
...
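For context, the updated get_input_ids in the engine example now returns pre-tokenized text plus the image URL, leaving the image preprocessing to the engine itself. A minimal sketch of that consumption side follows. It is not part of this commit: the Engine keyword arguments (skip_tokenizer_init, input_ids, image_data, sampling_params), the output_ids result key, and the placeholder image URL are assumptions based on the surrounding token-in-token-out examples, not something this diff shows. The server example below makes the same change on the client side.

# Hedged sketch (not part of this commit): driving the offline engine with
# token-in / token-out inputs. Names marked "assumed" are not from the diff.
from transformers import AutoProcessor

from sglang import Engine
from sglang.lang.chat_template import get_chat_template_by_model_path

MODEL_PATH = "Qwen/Qwen2-VL-2B"
IMAGE_URL = "https://example.com/image.jpg"  # placeholder; the example uses DEFAULT_IMAGE_URL

chat_template = get_chat_template_by_model_path(MODEL_PATH)
processor = AutoProcessor.from_pretrained(MODEL_PATH)

# Tokenize only the text; the image stays a URL and is handed to the engine
# via image_data, mirroring the updated get_input_ids above.
prompt = f"{chat_template.image_token}What is in this picture?"
input_ids = processor.tokenizer(text=[prompt], return_tensors="pt").input_ids[0].tolist()

# skip_tokenizer_init (assumed flag) makes the engine accept raw token IDs
# and return raw output token IDs instead of decoded text.
llm = Engine(model_path=MODEL_PATH, skip_tokenizer_init=True)
out = llm.generate(
    input_ids=input_ids,
    image_data=[IMAGE_URL],
    sampling_params={"temperature": 0, "max_new_tokens": 32},
)
print(out["output_ids"])  # assumed result key; decode with processor.tokenizer.decode(...)
llm.shutdown()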
...
@@ -5,11 +5,9 @@ python token_in_token_out_vlm_server.py
 """
-from io import BytesIO
 from typing import Tuple
 import requests
-from PIL import Image
 from transformers import AutoProcessor
 from sglang.lang.chat_template import get_chat_template_by_model_path
...
@@ -28,18 +26,20 @@ MODEL_PATH = "Qwen/Qwen2-VL-2B"
 def get_input_ids() -> Tuple[list[int], list]:
     chat_template = get_chat_template_by_model_path(MODEL_PATH)
     text = f"{chat_template.image_token}What is in this picture?"
-    images = [Image.open(BytesIO(requests.get(DEFAULT_IMAGE_URL).content))]
     image_data = [DEFAULT_IMAGE_URL]
 
     processor = AutoProcessor.from_pretrained(MODEL_PATH)
-    inputs = processor(
-        text=[text],
-        images=images,
-        return_tensors="pt",
-    )
-    return inputs.input_ids[0].tolist(), image_data
+    input_ids = (
+        processor.tokenizer(
+            text=[text],
+            return_tensors="pt",
+        )
+        .input_ids[0]
+        .tolist()
+    )
+    return input_ids, image_data
 
 
 def main():
...
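The server variant does the same tokenization on the client and ships token IDs plus the image URL over HTTP. A hedged sketch of such a request is below. It is not part of this commit: the launch command, the default port 30000, and the JSON field names (input_ids, image_data, sampling_params) are assumptions based on SGLang's native /generate API as used by these examples.

# Hedged sketch (not part of this commit): sending the token-in request.
# Assumes a server started roughly like:
#   python -m sglang.launch_server --model-path Qwen/Qwen2-VL-2B --skip-tokenizer-init
import requests

input_ids, image_data = get_input_ids()  # helper defined in the diff above

response = requests.post(
    "http://localhost:30000/generate",  # assumed default port and endpoint
    json={
        "input_ids": input_ids,
        "image_data": image_data,
        "sampling_params": {"temperature": 0, "max_new_tokens": 32},
    },
)
print(response.json())  # with --skip-tokenizer-init the response carries output token IDs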