Unverified commit 9eb49e87, authored by XinyuanTong, committed by GitHub
Browse files

[VLM RLHF] Take Image input for verl vlm rollout (#4915)


Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
Co-authored-by: GeLee <leege233@gmail.com>
parent 12047f5e
......@@ -151,10 +151,6 @@ class Engine:
The arguments of this function are the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
Please refer to `GenerateReqInput` for the documentation.
"""
modalities_list = []
if image_data is not None:
modalities_list.append("image")
obj = GenerateReqInput(
text=prompt,
input_ids=input_ids,
......@@ -165,7 +161,6 @@ class Engine:
top_logprobs_num=top_logprobs_num,
token_ids_logprob=token_ids_logprob,
lora_path=lora_path,
modalities=modalities_list,
custom_logit_processor=custom_logit_processor,
return_hidden_states=return_hidden_states,
stream=stream,
......
......@@ -139,8 +139,6 @@ class BaseMultimodalProcessor(ABC):
else:
multimodal_tokens.image_token = multimodal_tokens.image_token
assert isinstance(prompt, str)
if isinstance(prompt, list) and return_text:
assert len(prompt) and isinstance(prompt[0], int)
prompt = self._processor.tokenizer.decode(prompt)
......@@ -204,7 +202,16 @@ class BaseMultimodalProcessor(ABC):
continue
image_sizes += frames[0].size * len(frames)
hashes += [hash(image_file)] * len(frames)
# Generate a hashable value for the image file
if isinstance(image_file, Image.Image):
# For PIL.Image objects, use the ID as a hashable value
hash_value = hash(id(image_file))
else:
# For other types (strings, etc.), use the regular hash
hash_value = hash(image_file)
hashes += [hash_value] * len(frames)
images += frames
image_index += 1
if frames_to_process != 0:
......
......@@ -5,7 +5,7 @@ from typing import List, Union
import torch
from PIL import Image
from sglang.srt.managers.multimodal_processor import (
from sglang.srt.managers.multimodal_processors.base_processor import (
BaseMultimodalProcessor as SGLangBaseProcessor,
)
from sglang.srt.managers.multimodal_processors.base_processor import (
......
......@@ -566,10 +566,14 @@ def encode_video(video_path, frame_count_limit=None):
return frames
def load_image(image_file: Union[str, bytes]) -> tuple[Image, tuple[int, int]]:
def load_image(
image_file: Union[Image.Image, str, bytes]
) -> tuple[Image.Image, tuple[int, int]]:
image = image_size = None
if isinstance(image_file, bytes):
if isinstance(image_file, Image.Image):
image = image_file
image_size = (image.width, image.height)
elif isinstance(image_file, bytes):
image = Image.open(BytesIO(image_file))
elif image_file.startswith("http://") or image_file.startswith("https://"):
timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment