"src/vscode:/vscode.git/clone" did not exist on "fac75e166b4a4a7f84ac5d12c3b8f4ba01cda57b"
Unverified commit a5b14ad0 authored by Kaichen Zhang - NTU, committed by GitHub

[Feat/WIP] add llava-onevision, with support for (1) siglip encoder, (2) qwen2 decoder (3) openai api compatible server. (#1123)
Co-authored-by: Bo Li <drluodian@gmail.com>
parent 5fafcac0
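Besides the SigLIP encoder and Qwen2 decoder, this PR exposes LLaVA-OneVision through the OpenAI-compatible server. A minimal client sketch, assuming a server launched locally on port 30000 with one of the README commands in the diff below (the image URL and prompt are illustrative):

```python
# Minimal sketch: query the OpenAI-compatible endpoint added for llava-onevision.
# Assumes `python3 -m sglang.launch_server ... --port=30000` is already running.
import openai

client = openai.Client(api_key="EMPTY", base_url="http://127.0.0.1:30000/v1")
response = client.chat.completions.create(
    model="default",
    messages=[
        {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png"
                    },
                },
                {"type": "text", "text": "Please describe this image."},
            ],
        }
    ],
    temperature=0,
    max_tokens=256,
)
print(response.choices[0].message.content)
```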
......@@ -231,8 +231,13 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct
- `python -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --chat-template vicuna_v1.1 --port 30000`
- `python -m sglang.launch_server --model-path liuhaotian/llava-v1.6-vicuna-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --chat-template vicuna_v1.1 --port 30000`
- `python -m sglang.launch_server --model-path liuhaotian/llava-v1.6-34b --tokenizer-path liuhaotian/llava-v1.6-34b-tokenizer --port 30000`
- `python -m sglang.launch_server --model-path lmms-lab/llama3-llava-next-8b --port=30000 --host=127.0.0.1 --tp-size=1 --chat-template=llava_llama_3`
- `python -m sglang.launch_server --model-path lmms-lab/llava-next-72b --port=30000 --host="127.0.0.1" --tp-size=8 --chat-template=chatml-llava`
- LLaVA-NeXT-Video
- see [examples/usage/llava_video](examples/usage/llava_video)
- [LLaVA-OneVision](https://arxiv.org/abs/2408.03326)
- `python3 -m sglang.launch_server --model-path lmms-lab/llava-onevision-qwen2-72b-ov --port=30000 --host=127.0.0.1 --tp-size=8 --chat-template=chatml-llava --chunked-prefill-size=16384`
- see [test/srt/test_llava_onevision_openai_server.py](test/srt/test_llava_onevision_openai_server.py)
- Yi-VL
- see [srt_example_yi_vl.py](examples/quick_start/srt_example_yi_vl.py).
- StableLM
......
import base64
import io
import os
import sys
import time
import numpy as np
import openai
import requests
from decord import VideoReader, cpu
from PIL import Image
# pip install httpx==0.23.3
# pip install decord
# pip install protobuf==3.20.0
def download_video(url, cache_dir):
file_path = os.path.join(cache_dir, "jobs.mp4")
os.makedirs(cache_dir, exist_ok=True)
response = requests.get(url)
response.raise_for_status()
with open(file_path, "wb") as f:
f.write(response.content)
print(f"File downloaded and saved to: {file_path}")
return file_path
def create_openai_client(base_url):
return openai.Client(api_key="EMPTY", base_url=base_url)
def image_stream_request_test(client):
print("----------------------Image Stream Request Test----------------------")
stream_request = client.chat.completions.create(
model="default",
messages=[
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png"
},
},
{
"type": "text",
"text": "Please describe this image. Please list the benchmarks and the models.",
},
],
},
],
temperature=0.7,
max_tokens=1024,
stream=True,
)
stream_response = ""
for chunk in stream_request:
if chunk.choices[0].delta.content is not None:
content = chunk.choices[0].delta.content
stream_response += content
sys.stdout.write(content)
sys.stdout.flush()
print("-" * 30)
def video_stream_request_test(client, video_path):
print("------------------------Video Stream Request Test----------------------")
messages = prepare_video_messages(video_path)
start_time = time.time()
video_request = client.chat.completions.create(
model="default",
messages=messages,
temperature=0,
max_tokens=1024,
stream=True,
)
print("-" * 30)
video_response = ""
for chunk in video_request:
if chunk.choices[0].delta.content is not None:
content = chunk.choices[0].delta.content
video_response += content
sys.stdout.write(content)
sys.stdout.flush()
print("-" * 30)
def image_speed_test(client):
print("----------------------Image Speed Test----------------------")
start_time = time.time()
request = client.chat.completions.create(
model="default",
messages=[
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png"
},
},
{
"type": "text",
"text": "Please describe this image. Please list the benchmarks and the models.",
},
],
},
],
temperature=0,
max_tokens=1024,
)
end_time = time.time()
response = request.choices[0].message.content
print(response)
print("-" * 30)
print_speed_test_results(request, start_time, end_time)
def video_speed_test(client, video_path):
print("------------------------Video Speed Test------------------------")
messages = prepare_video_messages(video_path)
start_time = time.time()
video_request = client.chat.completions.create(
model="default",
messages=messages,
temperature=0,
max_tokens=1024,
)
end_time = time.time()
video_response = video_request.choices[0].message.content
print(video_response)
print("-" * 30)
print_speed_test_results(video_request, start_time, end_time)
def prepare_video_messages(video_path):
max_frames_num = 32
vr = VideoReader(video_path, ctx=cpu(0))
total_frame_num = len(vr)
uniform_sampled_frames = np.linspace(
0, total_frame_num - 1, max_frames_num, dtype=int
)
frame_idx = uniform_sampled_frames.tolist()
frames = vr.get_batch(frame_idx).asnumpy()
base64_frames = []
for frame in frames:
pil_img = Image.fromarray(frame)
buff = io.BytesIO()
pil_img.save(buff, format="JPEG")
base64_str = base64.b64encode(buff.getvalue()).decode("utf-8")
base64_frames.append(base64_str)
messages = [{"role": "user", "content": []}]
frame_format = {
"type": "image_url",
"image_url": {"url": "data:image/jpeg;base64,{}"},
}
for base64_frame in base64_frames:
frame_format["image_url"]["url"] = "data:image/jpeg;base64,{}".format(
base64_frame
)
messages[0]["content"].append(frame_format.copy())
prompt = {"type": "text", "text": "Please describe the video in detail."}
messages[0]["content"].append(prompt)
return messages
def print_speed_test_results(request, start_time, end_time):
total_tokens = request.usage.total_tokens
completion_tokens = request.usage.completion_tokens
prompt_tokens = request.usage.prompt_tokens
print(f"Total tokens: {total_tokens}")
print(f"Completion tokens: {completion_tokens}")
print(f"Prompt tokens: {prompt_tokens}")
print(f"Time taken: {end_time - start_time} seconds")
print(f"Token per second: {total_tokens / (end_time - start_time)}")
print(f"Completion token per second: {completion_tokens / (end_time - start_time)}")
print(f"Prompt token per second: {prompt_tokens / (end_time - start_time)}")
def main():
url = "https://raw.githubusercontent.com/EvolvingLMMs-Lab/sglang/dev/onevision_local/assets/jobs.mp4"
cache_dir = os.path.expanduser("~/.cache")
video_path = download_video(url, cache_dir)
client = create_openai_client("http://127.0.0.1:30000/v1")
image_stream_request_test(client)
video_stream_request_test(client, video_path)
image_speed_test(client)
video_speed_test(client, video_path)
if __name__ == "__main__":
main()
......@@ -121,6 +121,20 @@ def batch(video_dir, save_dir, cur_chunk, num_chunks, num_frames=16, batch_size=
if __name__ == "__main__":
url = "https://raw.githubusercontent.com/EvolvingLMMs-Lab/sglang/dev/onevision_local/assets/jobs.mp4"
cache_dir = os.path.expanduser("~/.cache")
file_path = os.path.join(cache_dir, "jobs.mp4")
os.makedirs(cache_dir, exist_ok=True)
response = requests.get(url)
response.raise_for_status() # Raise an exception for bad responses
with open(file_path, "wb") as f:
f.write(response.content)
print(f"File downloaded and saved to: {file_path}")
# Create the parser
parser = argparse.ArgumentParser(
description="Run video processing with specified port."
......@@ -148,7 +162,7 @@ if __name__ == "__main__":
parser.add_argument(
"--video-dir",
type=str,
default="./videos/Q98Z4OTh8RwmDonc.mp4",
default=os.path.expanduser("~/.cache/jobs.mp4"),
help="The directory or path for the processed video files.",
)
parser.add_argument(
......
......@@ -20,7 +20,7 @@ dependencies = [
]
[project.optional-dependencies]
srt = ["aiohttp", "fastapi", "hf_transfer", "huggingface_hub", "interegular",
srt = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hub", "interegular",
"packaging", "pillow", "psutil", "pydantic", "python-multipart",
"torch", "uvicorn", "uvloop", "zmq",
"vllm==0.5.4", "outlines>=0.0.44"]
......
......@@ -137,7 +137,7 @@ register_chat_template(
register_chat_template(
ChatTemplate(
name="chatml-llava",
default_system_prompt="Answer the questions.",
default_system_prompt="You are a helpful assistant.",
role_prefix_and_suffix={
"system": ("<|im_start|>system\n", "<|im_end|>\n"),
"user": ("<|im_start|>user\n", "<|im_end|>\n"),
......@@ -145,7 +145,7 @@ register_chat_template(
},
style=ChatTemplateStyle.PLAIN,
stop_str=("<|im_end|>",),
image_token=" <image>\n",
image_token="<image>\n",
)
)
......@@ -322,12 +322,17 @@ def match_chat_ml(model_path: str):
if "tinyllama" in model_path:
return get_chat_template("chatml")
# Now the suffix for qwen2 chat model is "instruct"
if "qwen" in model_path and ("chat" in model_path or "instruct" in model_path):
if (
"qwen" in model_path
and ("chat" in model_path or "instruct" in model_path)
and ("llava" not in model_path)
):
return get_chat_template("qwen")
if (
"llava-v1.6-34b" in model_path
or "llava-v1.6-yi-34b" in model_path
or "llava-next-video-34b" in model_path
or "llava-onevision-qwen2" in model_path
):
return get_chat_template("chatml-llava")
......
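For reference, the updated `chatml-llava` template above renders a single-image turn roughly as follows; this is a sketch of the assembled prompt string based on the role prefixes/suffixes and image token in the diff, with the user text chosen for illustration:

```python
# Sketch: prompt assembled by the "chatml-llava" template for one user turn
# containing one image (user text is illustrative).
system = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
user = "<|im_start|>user\n<image>\nDescribe this image.<|im_end|>\n"
assistant = "<|im_start|>assistant\n"  # generation stops at "<|im_end|>"
prompt = system + user + assistant
```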
......@@ -34,6 +34,7 @@ class SeparatorStyle(IntEnum):
NO_COLON_TWO = auto()
ADD_NEW_LINE_SINGLE = auto()
LLAMA2 = auto()
LLAMA3 = auto()
CHATGLM = auto()
CHATML = auto()
CHATINTERN = auto()
......@@ -137,6 +138,20 @@ class Conversation:
else:
ret += role + ":"
return ret
elif self.sep_style == SeparatorStyle.LLAMA3:
ret = "<|begin_of_text|>"
if self.system_message:
ret += system_prompt
else:
ret += ""
for i, (role, message) in enumerate(self.messages):
if message:
ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n"
ret += f"{message.strip()}<|eot_id|>"
else:
ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n"
# print(ret)
return ret
elif self.sep_style == SeparatorStyle.LLAMA2:
seps = [self.sep, self.sep2]
if self.system_message:
......@@ -379,12 +394,23 @@ def generate_chat_conv(
conv.append_message(conv.roles[0], message.content)
else:
real_content = ""
# calculate number of image_url
num_image_url = 0
for content in message.content:
if content.type == "image_url":
num_image_url += 1
if num_image_url > 1:
image_token = "<image>"
else:
image_token = "<image>\n"
for content in message.content:
if content.type == "text":
if num_image_url > 16:
real_content += "\n" # for video
real_content += content.text
elif content.type == "image_url":
# NOTE: Only works for llava
real_content += "<image>\n"
real_content += image_token
conv.append_image(content.image_url.url)
conv.append_message(conv.roles[0], real_content)
elif msg_role == "assistant":
......@@ -425,6 +451,18 @@ register_conv_template(
)
)
register_conv_template(
Conversation(
name="chatml-llava",
system_template="<|im_start|>system\n{system_message}",
system_message="You are a helpful assistant.",
roles=("<|im_start|>user", "<|im_start|>assistant"),
sep_style=SeparatorStyle.CHATML,
sep="<|im_end|>",
stop_str=["<|endoftext|>", "<|im_end|>"],
)
)
register_conv_template(
Conversation(
name="vicuna_v1.1",
......@@ -437,6 +475,17 @@ register_conv_template(
)
)
register_conv_template(
Conversation(
name="llava_llama_3",
system_message="You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.",
system_template="<|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>",
roles=("user", "assistant"),
sep_style=SeparatorStyle.LLAMA3,
sep="",
stop_str=["<|end_of_text|>", "<|eot_id|>"],
)
)
# Reference: https://github.com/InternLM/lmdeploy/blob/387bf54b4f124e72aab30ae9755f562e435d3d01/lmdeploy/model.py#L425-L442
register_conv_template(
Conversation(
......
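The `generate_chat_conv` change above picks the image token per request: a single image keeps the trailing newline, multiple images drop it, and more than 16 images (treated as video frames) also get a newline prepended to the text. A standalone sketch of that decision, using a simplified content representation:

```python
# Sketch of the image-token selection in generate_chat_conv above.
# `contents` is a simplified list of ("text", value) / ("image_url", value) pairs.
def build_user_content(contents):
    num_image_url = sum(1 for kind, _ in contents if kind == "image_url")
    image_token = "<image>" if num_image_url > 1 else "<image>\n"
    real_content = ""
    for kind, value in contents:
        if kind == "text":
            if num_image_url > 16:  # video: many sampled frames
                real_content += "\n"
            real_content += value
        elif kind == "image_url":
            real_content += image_token
    return real_content
```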
......@@ -131,11 +131,49 @@ class TokenizerManager:
self.model_update_lock = asyncio.Lock()
self.model_update_result = None
async def get_pixel_values(self, image_data):
aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
async def get_pixel_values(self, image_data, aspect_ratio=None):
aspect_ratio = (
getattr(self.hf_config, "image_aspect_ratio", None)
if aspect_ratio is None
else aspect_ratio
)
grid_pinpoints = (
self.hf_config.image_grid_pinpoints if aspect_ratio == "anyres" else None
self.hf_config.image_grid_pinpoints
if hasattr(self.hf_config, "image_grid_pinpoints")
and "anyres" in aspect_ratio
else None
)
if isinstance(image_data, list) and len(image_data) > 0:
pixel_values, image_hash, image_size = [], [], []
if len(image_data) > 1:
aspect_ratio = "pad" # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
for img_data in image_data:
pixel_v, image_h, image_s = await self._process_single_image(
img_data, aspect_ratio, grid_pinpoints
)
pixel_values.append(pixel_v)
image_hash.append(image_h)
image_size.append(image_s)
pixel_values = np.stack(pixel_values, axis=0)
else:
pixel_values, image_hash, image_size = await self._process_single_image(
image_data[0], aspect_ratio, grid_pinpoints
)
image_hash = [image_hash]
image_size = [image_size]
elif isinstance(image_data, str):
pixel_values, image_hash, image_size = await self._process_single_image(
image_data, aspect_ratio, grid_pinpoints
)
image_hash = [image_hash]
image_size = [image_size]
else:
pixel_values, image_hash, image_size = None, None, None
return pixel_values, image_hash, image_size
async def _process_single_image(self, image_data, aspect_ratio, grid_pinpoints):
if self.executor is not None:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
......@@ -194,8 +232,8 @@ class TokenizerManager:
)
if self.is_generation:
pixel_values, image_hash, image_size = await self._get_pixel_values(
obj.image_data if not_use_index else obj.image_data[index]
pixel_values, image_hash, image_size = await self.get_pixel_values(
obj.image_data
)
return_logprob = (
obj.return_logprob if not_use_index else obj.return_logprob[index]
......@@ -704,7 +742,7 @@ def get_pixel_values(
tuple(int(x * 255) for x in processor.image_processor.image_mean),
)
pixel_values = processor.image_processor(image)["pixel_values"][0]
elif image_aspect_ratio == "anyres":
elif image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
pixel_values = process_anyres_image(
image, processor.image_processor, image_grid_pinpoints
)
......
......@@ -322,11 +322,16 @@ class ModelTpServer:
if self.model_runner.is_generation:
req.pixel_values = recv_req.pixel_values
if req.pixel_values is not None:
image_hash = (
hash(tuple(recv_req.image_hash))
if isinstance(recv_req.image_hash, list)
else recv_req.image_hash
)
req.pad_value = [
(recv_req.image_hash) % self.model_config.vocab_size,
(recv_req.image_hash >> 16) % self.model_config.vocab_size,
(recv_req.image_hash >> 32) % self.model_config.vocab_size,
(recv_req.image_hash >> 64) % self.model_config.vocab_size,
(image_hash) % self.model_config.vocab_size,
(image_hash >> 16) % self.model_config.vocab_size,
(image_hash >> 32) % self.model_config.vocab_size,
(image_hash >> 64) % self.model_config.vocab_size,
]
req.image_size = recv_req.image_size
(
......
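The ModelTpServer change above collapses a per-image hash list into one integer before deriving the pad tokens. A compact sketch of that derivation, where `vocab_size` stands in for `self.model_config.vocab_size`:

```python
# Sketch: pad_value derivation from an image hash, mirroring the change above.
def compute_pad_value(image_hash, vocab_size):
    if isinstance(image_hash, list):  # multi-image or video request
        image_hash = hash(tuple(image_hash))
    return [
        image_hash % vocab_size,
        (image_hash >> 16) % vocab_size,
        (image_hash >> 32) % vocab_size,
        (image_hash >> 64) % vocab_size,
    ]
```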
......@@ -13,10 +13,25 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
# Source: https://github.com/haotian-liu/LLaVA/blob/main/llava/mm_utils.py
# Source: https://github.com/LLaVA-VL/LLaVA-NeXT/blob/main/llava/mm_utils.py
"""
Utilities for multi-modal models.
This python file mainly contains utilities that were used in the
image processing logic of llava-next including operations such as
anyres and anyres_max
Currently supports the anyres and anyres_max operation for CLIP and
SigLip. For more information, you may refer to the paper or the blog
LLaVA-NeXT : https://llava-vl.github.io/blog/2024-01-30-llava-next/
LLaVA-Onevision : https://arxiv.org/pdf/2408.03326
"""
import ast
import base64
import math
import re
from io import BytesIO
import numpy as np
......@@ -40,10 +55,13 @@ def select_best_resolution(original_size, possible_resolutions):
min_wasted_resolution = float("inf")
for width, height in possible_resolutions:
# Calculate the downscaled size to keep the aspect ratio
scale = min(width / original_width, height / original_height)
downscaled_width, downscaled_height = int(original_width * scale), int(
original_height * scale
)
# Calculate effective and wasted resolutions
effective_resolution = min(
downscaled_width * downscaled_height, original_width * original_height
)
......@@ -129,6 +147,26 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
Returns:
tuple: The shape of the image patch grid in the format (width, height).
"""
if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
assert patch_size in [
224,
336,
384,
448,
512,
], "patch_size should be in [224, 336, 384, 448, 512]"
# Use regex to extract the range from the input string
matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
range_start = tuple(map(int, matches[0]))
range_end = tuple(map(int, matches[-1]))
# Generate a matrix of tuples from (range_start[0], range_start[1]) to (range_end[0], range_end[1])
grid_pinpoints = [
(i, j)
for i in range(range_start[0], range_end[0] + 1)
for j in range(range_start[1], range_end[1] + 1)
]
# Multiply all elements by patch_size
grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
if type(grid_pinpoints) is list:
possible_resolutions = grid_pinpoints
else:
......@@ -149,6 +187,31 @@ def process_anyres_image(image, processor, grid_pinpoints):
Returns:
np.array: An np array containing the processed image patches.
"""
if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
try:
patch_size = processor.size[0]
except Exception as e:
patch_size = processor.size["shortest_edge"]
assert patch_size in [
224,
336,
384,
448,
512,
], "patch_size should be in [224, 336, 384, 448, 512]"
# Use regex to extract the range from the input string
matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
range_start = tuple(map(int, matches[0]))
range_end = tuple(map(int, matches[-1]))
# Generate a matrix of tuples from (range_start[0], range_start[1]) to (range_end[0], range_end[1])
grid_pinpoints = [
(i, j)
for i in range(range_start[0], range_end[0] + 1)
for j in range(range_start[1], range_end[1] + 1)
]
# Multiply all elements by patch_size
grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
if type(grid_pinpoints) is list:
possible_resolutions = grid_pinpoints
else:
......@@ -156,15 +219,24 @@ def process_anyres_image(image, processor, grid_pinpoints):
best_resolution = select_best_resolution(image.size, possible_resolutions)
image_padded = resize_and_pad_image(image, best_resolution)
patches = divide_to_patches(image_padded, processor.crop_size["height"])
image_original_resize = image.resize(
(processor.size["shortest_edge"], processor.size["shortest_edge"])
# The SigLIP processor only defines `size`, not `crop_size`
crop_size = (
processor.crop_size["height"]
if "crop_size" in processor.__dict__
else processor.size["height"]
)
shortest_edge = (
processor.size["shortest_edge"]
if "shortest_edge" in processor.size
else processor.size["height"]
)
patches = divide_to_patches(image_padded, crop_size)
image_original_resize = image.resize((shortest_edge, shortest_edge))
image_patches = [image_original_resize] + patches
image_patches = [
processor.preprocess(image_patch)["pixel_values"][0]
processor.preprocess(image_patch.convert("RGB"))["pixel_values"][0]
for image_patch in image_patches
]
return np.stack(image_patches, axis=0)
......@@ -255,7 +327,7 @@ def process_images(images, image_processor, model_cfg):
)
image = image_processor.preprocess(image)["pixel_values"][0]
new_images.append(image)
elif image_aspect_ratio == "anyres":
elif "anyres" in image_aspect_ratio:
for image in images:
image = process_anyres_image(
image, image_processor, model_cfg.image_grid_pinpoints
......
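Both mm_utils helpers above now accept `grid_pinpoints` as a range string such as "(1x1),...,(6x6)" in addition to an explicit resolution list; the string is expanded into every (i, j) grid in the range and scaled by the patch size. A standalone sketch of that expansion (the example string and patch size are assumptions for illustration):

```python
import re

# Sketch of the grid_pinpoints range-string expansion used in mm_utils above.
def expand_grid_pinpoints(grid_pinpoints, patch_size):
    matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
    range_start = tuple(map(int, matches[0]))
    range_end = tuple(map(int, matches[-1]))
    grids = [
        (i, j)
        for i in range(range_start[0], range_end[0] + 1)
        for j in range(range_start[1], range_end[1] + 1)
    ]
    # Scale each grid to pixel resolutions.
    return [[dim * patch_size for dim in pair] for pair in grids]

# expand_grid_pinpoints("(1x1),...,(6x6)", 384)[:3] -> [[384, 384], [384, 768], [384, 1152]]
```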
......@@ -88,14 +88,19 @@ class InputMetadata:
reqs = batch.reqs
self.pixel_values = [r.pixel_values for r in reqs]
self.image_sizes = [r.image_size for r in reqs]
self.image_offsets = [
(
(r.image_offset - batch.prefix_lens_cpu[i])
if r.image_offset is not None
else 0
)
for i, r in enumerate(reqs)
]
self.image_offsets = []
for r in reqs:
if isinstance(r.image_offset, list):
self.image_offsets.append(
[
(image_offset - len(r.prefix_indices))
for image_offset in r.image_offset
]
)
elif isinstance(r.image_offset, int):
self.image_offsets.append(r.image_offset - len(r.prefix_indices))
elif r.image_offset is None:
self.image_offsets.append(0)
def compute_positions(self, batch: ScheduleBatch):
position_ids_offsets = batch.position_ids_offsets
......
......@@ -15,6 +15,8 @@ limitations under the License.
"""Inference-only LLaVa model compatible with HuggingFace weights."""
import math
import re
from typing import Iterable, List, Optional, Tuple
import numpy as np
......@@ -26,6 +28,8 @@ from transformers import (
LlavaConfig,
MistralConfig,
Qwen2Config,
SiglipVisionConfig,
SiglipVisionModel,
)
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
from vllm.config import CacheConfig
......@@ -63,34 +67,61 @@ class LlavaLlamaForCausalLM(nn.Module):
)
def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None):
new_image_feature_len = self.image_feature_len
# now only support spatial_unpad + anyres
if self.mm_patch_merge_type.startswith("spatial"):
# hardcode for spatial_unpad + anyres
image_aspect_ratio = "anyres" if len(image_size) == 1 else "pad"
offset_list = []
for image_s in image_size:
if len(image_size) > 16:
# 2x2 pooling with stride 2
new_image_feature_len = (
math.ceil(self.image_size / self.patch_size / 2) ** 2
)
else:
new_image_feature_len = self.image_feature_len # multiimage
height = width = self.num_patches_per_side
if pt_shape[0] > 1:
if self.image_aspect_ratio == "anyres":
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
image_size,
self.image_grid_pinpoints,
self.vision_tower.config.image_size,
if "anyres" in image_aspect_ratio:
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
image_s,
self.image_grid_pinpoints,
self.vision_tower.config.image_size,
)
h = num_patch_height * height
w = num_patch_width * width
new_h, new_w = unpad_image_shape(h, w, image_s)
if "anyres_max" in self.config.image_aspect_ratio:
matched_anyres_max_num_patches = re.match(
r"anyres_max_(\d+)", self.config.image_aspect_ratio
)
if matched_anyres_max_num_patches:
max_num_patches = int(matched_anyres_max_num_patches.group(1))
# times = math.sqrt(h * w / (max_num_patches * unit**2))
times = math.sqrt(
new_h * new_w / (max_num_patches * self.image_feature_len)
)
if "unpad" in self.mm_patch_merge_type:
h = num_patch_height * height
w = num_patch_width * width
new_h, new_w = unpad_image_shape(h, w, image_size)
new_image_feature_len += new_h * (new_w + 1)
pad_ids = pad_value * (
(new_image_feature_len + len(pad_value)) // len(pad_value)
)
offset = input_ids.index(self.config.image_token_index)
# old_len + pad_len - 1, because we need to remove image_token_id
new_input_ids = (
input_ids[:offset]
+ pad_ids[:new_image_feature_len]
+ input_ids[offset + 1 :]
)
return new_input_ids, offset
if times > 1.1:
new_h = int(new_h // times)
new_w = int(new_w // times)
new_image_feature_len += new_h * (new_w + 1)
pad_ids = pad_value * (
(new_image_feature_len + len(pad_value)) // len(pad_value)
)
# print("calculated new_image_feature_len: ", new_image_feature_len)
try:
offset = input_ids.index(self.config.image_token_index)
except ValueError:
offset = 0
# old_len + pad_len - 1, because we need to remove image_token_id
input_ids = (
input_ids[:offset]
+ pad_ids[:new_image_feature_len]
+ input_ids[offset + 1 :]
)
offset_list.append(offset)
return input_ids, offset_list
def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
......@@ -124,7 +155,6 @@ class LlavaLlamaForCausalLM(nn.Module):
# Embed text input
input_embeds = self.language_model.model.embed_tokens(input_ids)
# Embed vision input
need_vision = (
(positions[input_metadata.extend_start_loc] < self.image_feature_len)
......@@ -163,27 +193,73 @@ class LlavaLlamaForCausalLM(nn.Module):
if self.mm_patch_merge_type.startswith("spatial"):
new_image_features = []
height = width = self.num_patches_per_side
for image_idx, image_feature in enumerate(image_features):
if image_feature.shape[0] > 1:
if len(image_sizes[image_idx]) == 1:
image_aspect_ratio = (
self.config.image_aspect_ratio
) # single image
else:
image_aspect_ratio = "pad" # multi image
# image_aspect_ratio = (
# "anyres" if len(image_sizes[image_idx]) == 1 else "pad"
# )
if (
image_feature.shape[0] > 1
and "anyres" in image_aspect_ratio
):
base_image_feature = image_feature[0]
image_feature = image_feature[1:]
height = width = self.num_patches_per_side
assert height * width == base_image_feature.shape[0]
if self.image_aspect_ratio == "anyres":
(
num_patch_width,
num_patch_height,
) = get_anyres_image_grid_shape(
image_sizes[image_idx],
self.image_grid_pinpoints,
self.vision_tower.config.image_size,
if "anyres_max" in image_aspect_ratio:
matched_anyres_max_num_patches = re.match(
r"anyres_max_(\d+)", image_aspect_ratio
)
if matched_anyres_max_num_patches:
max_num_patches = int(
matched_anyres_max_num_patches.group(1)
)
if (
image_aspect_ratio == "anyres"
or "anyres_max" in image_aspect_ratio
):
vision_tower_image_size = self.image_size
try:
num_patch_width, num_patch_height = (
get_anyres_image_grid_shape(
image_sizes[image_idx][0],
self.config.image_grid_pinpoints,
vision_tower_image_size,
)
)
except Exception as e:
print(f"Error: {e}")
num_patch_width, num_patch_height = 2, 2
image_feature = image_feature.view(
num_patch_height, num_patch_width, height, width, -1
)
else:
raise NotImplementedError()
image_feature = image_feature.view(
2, 2, height, width, -1
)
# (
# num_patch_width,
# num_patch_height,
# ) = get_anyres_image_grid_shape(
# image_sizes[image_idx][0],
# self.image_grid_pinpoints,
# self.vision_tower.config.image_size,
# )
# image_feature = image_feature.view(
# num_patch_height, num_patch_width, height, width, -1
# )
if "unpad" in self.mm_patch_merge_type:
unit = image_feature.shape[2]
image_feature = image_feature.permute(
4, 0, 2, 1, 3
).contiguous()
......@@ -191,8 +267,23 @@ class LlavaLlamaForCausalLM(nn.Module):
2, 3
)
image_feature = unpad_image(
image_feature, image_sizes[image_idx]
image_feature, image_sizes[image_idx][0]
)
if (
"anyres_max" in image_aspect_ratio
and matched_anyres_max_num_patches
):
c, h, w = image_feature.shape
times = math.sqrt(
h * w / (max_num_patches * unit**2)
)
if times > 1.1:
image_feature = image_feature[None]
image_feature = nn.functional.interpolate(
image_feature,
[int(h // times), int(w // times)],
mode="bilinear",
)[0]
image_feature = torch.cat(
(
image_feature,
......@@ -213,16 +304,31 @@ class LlavaLlamaForCausalLM(nn.Module):
image_feature = torch.cat(
(base_image_feature, image_feature), dim=0
)
image_feature = image_feature.unsqueeze(0)
else:
image_feature = image_feature[0]
if "unpad" in self.mm_patch_merge_type:
image_feature = torch.cat(
(
image_feature,
self.language_model.model.image_newline[None],
),
dim=0,
if image_feature.shape[0] > 16: # video
# 2x2 pooling
num_of_frames = image_feature.shape[0]
image_feature = image_feature.view(
num_of_frames, height, width, -1
)
image_feature = image_feature.permute(
0, 3, 1, 2
).contiguous() # N, C, H, W
height, width = image_feature.shape[2:]
scaled_shape = [
math.ceil(height / 2),
math.ceil(width / 2),
]
image_feature = nn.functional.interpolate(
image_feature, size=scaled_shape, mode="bilinear"
)
image_feature = (
image_feature.flatten(2)
.transpose(1, 2)
.contiguous()
) # N, C, H*W
new_image_features.append(image_feature)
image_features = new_image_features
......@@ -233,21 +339,22 @@ class LlavaLlamaForCausalLM(nn.Module):
continue
start_idx = extend_start_loc_cpu[i]
pad_len, pad_dim = image_features[pt].shape # 576, 4096
pad_dim = image_features[pt].shape[-1] # 576, 4096
dim = input_embeds.shape[1]
assert (
pad_dim == dim
), "invalid pad_dim={}, input_embed_dim={}!".format(pad_dim, dim)
# Fill in the placeholder for the image
try:
input_embeds[
start_idx
+ image_offsets[i] : start_idx
+ image_offsets[i]
+ pad_len
] = image_features[pt]
for j, image_off in enumerate(image_offsets[i]):
# print("actual image_features length: ", image_features[pt][j].shape[0])
pad_len = image_features[pt][j].shape[0]
input_embeds[
start_idx + image_off : start_idx + image_off + pad_len
] = image_features[pt][j]
except RuntimeError as e:
print(f"RuntimeError in llava image encoding: {e}")
print(image_features[pt].shape)
print(input_embeds.shape)
print(start_idx, image_offsets[i])
pt += 1
......@@ -262,9 +369,16 @@ class LlavaLlamaForCausalLM(nn.Module):
# load clip vision model by cfg['mm_vision_tower']:
# huggingface_name or path_of_clip_relative_to_llava_model_dir
vision_path = self.config.mm_vision_tower
self.vision_tower = CLIPVisionModel.from_pretrained(
vision_path, torch_dtype=torch.float16
).cuda()
if "clip" in vision_path:
self.vision_tower = CLIPVisionModel.from_pretrained(
vision_path, torch_dtype=torch.float16
).cuda()
elif "siglip" in vision_path:
self.vision_tower = SiglipVisionModel.from_pretrained(
vision_path, torch_dtype=torch.float16
).cuda()
# Siglip needs all feature tokens
self.config.mm_vision_select_feature = "full"
self.vision_tower.eval()
self.vision_feature_layer = self.config.mm_vision_select_layer
......@@ -276,8 +390,11 @@ class LlavaLlamaForCausalLM(nn.Module):
self.image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
self.image_grid_pinpoints = getattr(self.config, "image_grid_pinpoints", None)
self.image_feature_len = int((self.image_size / self.patch_size) ** 2)
if self.vision_feature_select_strategy == "patch":
self.image_feature_len = int((self.image_size // self.patch_size) ** 2)
if (
self.vision_feature_select_strategy == "patch"
or self.vision_feature_select_strategy == "full"
):
pass
elif self.vision_feature_select_strategy == "cls_patch":
self.image_feature_len += 1
......
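In `pad_input_ids` above, requests with more than 16 frames switch to a 2x2-pooled per-frame feature length. As a rough numeric check, assuming the SigLIP-384 settings used by the OneVision checkpoints (image_size 384, patch_size 14; these values are assumptions and come from the vision tower config in practice):

```python
import math

image_size, patch_size = 384, 14  # assumed; read from the vision tower config in practice
per_frame_full = int((image_size // patch_size) ** 2)           # 27 * 27 = 729 tokens per image
per_frame_pooled = math.ceil(image_size / patch_size / 2) ** 2  # 14 * 14 = 196 tokens per video frame
print(per_frame_full, per_frame_pooled)
```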
import base64
import io
import json
import os
import sys
import time
import unittest
import numpy as np
import openai
import requests
from decord import VideoReader, cpu
from PIL import Image
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.utils import kill_child_process
from sglang.test.test_utils import DEFAULT_URL_FOR_UNIT_TEST, popen_launch_server
# python3 -m sglang.launch_server --model-path lmms-lab/llava-onevision-qwen2-72b-ov --tokenizer-path lmms-lab/llavanext-qwen-siglip-tokenizer --port=30000 --host=127.0.0.1 --tp-size=8 --chat-template=chatml-llava --chunked-prefill-size=16384
class TestOpenAIVisionServer(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.model = "liuhaotian/llava-v1.6-vicuna-7b"
cls.model = "lmms-lab/llava-onevision-qwen2-0.5b-ov"
cls.base_url = DEFAULT_URL_FOR_UNIT_TEST
cls.api_key = "sk-123456"
cls.process = popen_launch_server(
......@@ -21,9 +31,11 @@ class TestOpenAIVisionServer(unittest.TestCase):
api_key=cls.api_key,
other_args=[
"--chat-template",
"vicuna_v1.1",
"chatml-llava",
"--tokenizer-path",
"llava-hf/llava-1.5-7b-hf",
"lmms-lab/llavanext-qwen-siglip-tokenizer",
"--chunked-prefill-size",
"16384",
"--log-requests",
],
)
......@@ -68,6 +80,81 @@ class TestOpenAIVisionServer(unittest.TestCase):
assert response.usage.completion_tokens > 0
assert response.usage.total_tokens > 0
def prepare_video_messages(self, video_path):
max_frames_num = 32
vr = VideoReader(video_path, ctx=cpu(0))
total_frame_num = len(vr)
uniform_sampled_frames = np.linspace(
0, total_frame_num - 1, max_frames_num, dtype=int
)
frame_idx = uniform_sampled_frames.tolist()
frames = vr.get_batch(frame_idx).asnumpy()
base64_frames = []
for frame in frames:
pil_img = Image.fromarray(frame)
buff = io.BytesIO()
pil_img.save(buff, format="JPEG")
base64_str = base64.b64encode(buff.getvalue()).decode("utf-8")
base64_frames.append(base64_str)
messages = [{"role": "user", "content": []}]
frame_format = {
"type": "image_url",
"image_url": {"url": "data:image/jpeg;base64,{}"},
}
for base64_frame in base64_frames:
frame_format["image_url"]["url"] = "data:image/jpeg;base64,{}".format(
base64_frame
)
messages[0]["content"].append(frame_format.copy())
prompt = {"type": "text", "text": "Please describe the video in detail."}
messages[0]["content"].append(prompt)
return messages
def test_video_chat_completion(self):
url = "https://raw.githubusercontent.com/EvolvingLMMs-Lab/sglang/dev/onevision_local/assets/jobs.mp4"
cache_dir = os.path.expanduser("~/.cache")
file_path = os.path.join(cache_dir, "jobs.mp4")
os.makedirs(cache_dir, exist_ok=True)
if not os.path.exists(file_path):
response = requests.get(url)
response.raise_for_status()
with open(file_path, "wb") as f:
f.write(response.content)
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
messages = self.prepare_video_messages(file_path)
start_time = time.time()
video_request = client.chat.completions.create(
model="default",
messages=messages,
temperature=0,
max_tokens=1024,
stream=True,
)
print("-" * 30)
video_response = ""
for chunk in video_request:
if chunk.choices[0].delta.content is not None:
content = chunk.choices[0].delta.content
video_response += content
sys.stdout.write(content)
sys.stdout.flush()
print("-" * 30)
# Add assertions to validate the video response
self.assertIsNotNone(video_response)
self.assertGreater(len(video_response), 0)
def test_regex(self):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
......