# Unique model identifier
modelCode=1419
# Model name
modelName=Qwen2.5-VL_pytorch
# Model description
modelDescription=Qwen2.5-VL enhances the model's perception of temporal and spatial scales, and further simplifies the network structure to improve model efficiency.
# Application scenarios
appScenario=Inference,Training,Conversational Q&A,Research,Education,Government,Finance
# Framework type
frameType=Pytorch
# qwen-vl-utils
Qwen-VL Utils contains a set of helper functions for processing and integrating visual language information with the Qwen-VL series models.
## Install
```bash
pip install qwen-vl-utils
```
## Usage
### Qwen2VL
```python
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
# You can directly insert a local file path, a URL, or a base64-encoded image at the position you want in the text.
messages = [
# Image
## Local file path
[{"role": "user", "content": [{"type": "image", "image": "file:///path/to/your/image.jpg"}, {"type": "text", "text": "Describe this image."}]}],
## Image URL
[{"role": "user", "content": [{"type": "image", "image": "http://path/to/your/image.jpg"}, {"type": "text", "text": "Describe this image."}]}],
## Base64 encoded image
[{"role": "user", "content": [{"type": "image", "image": "data:image;base64,/9j/..."}, {"type": "text", "text": "Describe this image."}]}],
## PIL.Image.Image
[{"role": "user", "content": [{"type": "image", "image": pil_image}, {"type": "text", "text": "Describe this image."}]}],
## The model dynamically adjusts the image size; specify dimensions if required.
[{"role": "user", "content": [{"type": "image", "image": "file:///path/to/your/image.jpg", "resized_height": 280, "resized_width": 420}, {"type": "text", "text": "Describe this image."}]}],
# Video
## Local video path
[{"role": "user", "content": [{"type": "video", "video": "file:///path/to/video1.mp4"}, {"type": "text", "text": "Describe this video."}]}],
## Local video frames
[{"role": "user", "content": [{"type": "video", "video": ["file:///path/to/extracted_frame1.jpg", "file:///path/to/extracted_frame2.jpg", "file:///path/to/extracted_frame3.jpg"],}, {"type": "text", "text": "Describe this video."},],}],
## The model dynamically adjusts the number of video frames (nframes) and the video height and width; specify these args if required.
[{"role": "user", "content": [{"type": "video", "video": "file:///path/to/video1.mp4", "fps": 2.0, "resized_height": 280, "resized_width": 280}, {"type": "text", "text": "Describe this video."}]}],
]
processor = AutoProcessor.from_pretrained(model_path)
model = Qwen2VLForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto", device_map="auto")
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
images, videos = process_vision_info(messages)
inputs = processor(text=text, images=images, videos=videos, padding=True, return_tensors="pt")
inputs = inputs.to(model.device)  # move inputs to the model device before generation
print(inputs)
generated_ids = model.generate(**inputs)
print(generated_ids)
```
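The example above prints raw token ids. As a minimal follow-up (not part of the original snippet), the prompt tokens can be stripped and the completion decoded with the same processor:
```python
# Minimal decoding sketch (assumes `inputs`, `generated_ids` and `processor` from the example above).
generated_ids_trimmed = [
    out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = processor.batch_decode(
    generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print(output_text)
```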
### Qwen2.5VL
```python
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
# You can cap the total number of video tokens through the environment variable VIDEO_MAX_PIXELS,
# based on the maximum number of tokens that the model can accept, e.g.:
# export VIDEO_MAX_PIXELS=$(( 32000 * 28 * 28 * 9 / 10 ))
# You can directly insert a local file path, a URL, or a base64-encoded image at the position you want in the text.
messages = [
# Image
## Local file path
[{"role": "user", "content": [{"type": "image", "image": "file:///path/to/your/image.jpg"}, {"type": "text", "text": "Describe this image."}]}],
## Image URL
[{"role": "user", "content": [{"type": "image", "image": "http://path/to/your/image.jpg"}, {"type": "text", "text": "Describe this image."}]}],
## Base64 encoded image
[{"role": "user", "content": [{"type": "image", "image": "data:image;base64,/9j/..."}, {"type": "text", "text": "Describe this image."}]}],
## PIL.Image.Image
[{"role": "user", "content": [{"type": "image", "image": pil_image}, {"type": "text", "text": "Describe this image."}]}],
## The model dynamically adjusts the image size; specify dimensions if required.
[{"role": "user", "content": [{"type": "image", "image": "file:///path/to/your/image.jpg", "resized_height": 280, "resized_width": 420}, {"type": "text", "text": "Describe this image."}]}],
# Video
## Local video path
[{"role": "user", "content": [{"type": "video", "video": "file:///path/to/video1.mp4"}, {"type": "text", "text": "Describe this video."}]}],
## Local video frames
[{"role": "user", "content": [{"type": "video", "video": ["file:///path/to/extracted_frame1.jpg", "file:///path/to/extracted_frame2.jpg", "file:///path/to/extracted_frame3.jpg"],}, {"type": "text", "text": "Describe this video."},],}],
## The model dynamically adjusts the number of video frames (nframes) and the video height and width; specify these args if required.
[{"role": "user", "content": [{"type": "video", "video": "file:///path/to/video1.mp4", "fps": 2.0, "resized_height": 280, "resized_width": 280}, {"type": "text", "text": "Describe this video."}]}],
]
processor = AutoProcessor.from_pretrained(model_path)
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto", device_map="auto")
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
images, videos, video_kwargs = process_vision_info(messages, return_video_kwargs=True)
inputs = processor(text=text, images=images, videos=videos, padding=True, return_tensors="pt", **video_kwargs)
inputs = inputs.to(model.device)  # move inputs to the model device before generation
print(inputs)
generated_ids = model.generate(**inputs)
print(generated_ids)
```
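Note that `VIDEO_MAX_PIXELS` is read once when `qwen_vl_utils` is imported, so the variable must be set before the import, either exported in the shell as in the comment above or from Python. A minimal sketch:
```python
import os

# Set before importing qwen_vl_utils, because VIDEO_MAX_PIXELS is read once at import time.
os.environ["VIDEO_MAX_PIXELS"] = str(int(32000 * 28 * 28 * 0.9))

from qwen_vl_utils import process_vision_info
```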
[project]
name = "qwen-vl-utils"
version = "0.0.10"
description = "Qwen Vision Language Model Utils - PyTorch"
authors = [
{ name = "Qwen Team", email = "chenkeqin.ckq@alibaba-inc.com" },
]
dependencies = [
"requests",
"pillow",
"av",
"packaging",
]
readme = "README.md"
requires-python = ">= 3.8"
license = {text = "Apache-2.0"}
keywords = [
'large language model',
'vision language model',
'qwen-vl',
'pytorch',
]
classifiers = [
'Development Status :: 4 - Beta',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'Programming Language :: Python :: 3',
'License :: OSI Approved :: Apache Software License',
]
[project.urls]
Homepage = "https://github.com/QwenLM/Qwen2-VL/tree/main/qwen-vl-utils"
Repository = "https://github.com/QwenLM/Qwen2-VL.git"
Issues = "https://github.com/QwenLM/Qwen2-VL/issues"
[project.optional-dependencies]
decord = [
"decord",
]
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.rye]
managed = true
dev-dependencies = [
"torch",
"torchvision",
]
[tool.hatch.metadata]
allow-direct-references = true
[tool.hatch.build.targets.wheel]
packages = ["src/qwen_vl_utils"]
[tool.ruff]
line-length = 119
[tool.ruff.lint]
ignore = ["C408", "C901", "E501", "E731", "E741", "W605"]
select = ["C", "E", "F", "I", "W"]
[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["E402", "F401", "F403", "F811"]
[tool.ruff.lint.isort]
lines-after-imports = 2
known-first-party = ["qwen_vl_utils"]
[tool.ruff.format]
quote-style = "double"
indent-style = "space"
skip-magic-trailing-comma = false
line-ending = "auto"
# generated by rye
# use `rye lock` or `rye sync` to update this lockfile
#
# last locked with the following flags:
# pre: false
# features: ["decord"]
# all-features: false
# with-sources: false
# generate-hashes: false
# universal: false
-e file:.
av==12.3.0
# via qwen-vl-utils
certifi==2022.12.7
# via requests
charset-normalizer==2.1.1
# via requests
decord==0.6.0
# via qwen-vl-utils
filelock==3.13.1
# via torch
# via triton
fsspec==2024.2.0
# via torch
idna==3.4
# via requests
jinja2==3.1.3
# via torch
markupsafe==2.1.5
# via jinja2
mpmath==1.3.0
# via sympy
networkx==3.1
# via torch
numpy==1.24.1
# via decord
# via torchvision
nvidia-cublas-cu12==12.1.3.1
# via nvidia-cudnn-cu12
# via nvidia-cusolver-cu12
# via torch
nvidia-cuda-cupti-cu12==12.1.105
# via torch
nvidia-cuda-nvrtc-cu12==12.1.105
# via torch
nvidia-cuda-runtime-cu12==12.1.105
# via torch
nvidia-cudnn-cu12==9.1.0.70
# via torch
nvidia-cufft-cu12==11.0.2.54
# via torch
nvidia-curand-cu12==10.3.2.106
# via torch
nvidia-cusolver-cu12==11.4.5.107
# via torch
nvidia-cusparse-cu12==12.1.0.106
# via nvidia-cusolver-cu12
# via torch
nvidia-nccl-cu12==2.20.5
# via torch
nvidia-nvjitlink-cu12==12.6.68
# via nvidia-cusolver-cu12
# via nvidia-cusparse-cu12
nvidia-nvtx-cu12==12.1.105
# via torch
packaging==24.1
# via qwen-vl-utils
pillow==10.2.0
# via qwen-vl-utils
# via torchvision
requests==2.28.1
# via qwen-vl-utils
sympy==1.12
# via torch
torch==2.4.0
# via torchvision
torchvision==0.19.0
triton==3.0.0
# via torch
typing-extensions==4.9.0
# via torch
urllib3==1.26.13
# via requests
# generated by rye
# use `rye lock` or `rye sync` to update this lockfile
#
# last locked with the following flags:
# pre: false
# features: ["decord"]
# all-features: false
# with-sources: false
# generate-hashes: false
# universal: false
-e file:.
av==12.3.0
# via qwen-vl-utils
certifi==2022.12.7
# via requests
charset-normalizer==2.1.1
# via requests
decord==0.6.0
# via qwen-vl-utils
idna==3.4
# via requests
numpy==1.24.4
# via decord
packaging==24.1
# via qwen-vl-utils
pillow==10.2.0
# via qwen-vl-utils
requests==2.28.1
# via qwen-vl-utils
urllib3==1.26.13
# via requests
from .vision_process import (
extract_vision_info,
fetch_image,
fetch_video,
process_vision_info,
smart_resize,
)
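# Package-level re-exports: process_vision_info is the entry point used in the README examples;
# fetch_image, fetch_video, extract_vision_info and smart_resize can also be imported directly
# from qwen_vl_utils.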
from __future__ import annotations
import base64
import logging
import math
import os
import sys
import time
import warnings
from functools import lru_cache
from io import BytesIO
import requests
import torch
import torchvision
from packaging import version
from PIL import Image
from torchvision import io, transforms
from torchvision.transforms import InterpolationMode
from typing import Optional
logger = logging.getLogger(__name__)
IMAGE_FACTOR = 28
MIN_PIXELS = 4 * 28 * 28
MAX_PIXELS = 16384 * 28 * 28
MAX_RATIO = 200
VIDEO_MIN_PIXELS = 128 * 28 * 28
VIDEO_MAX_PIXELS = 768 * 28 * 28
FRAME_FACTOR = 2
FPS = 2.0
FPS_MIN_FRAMES = 4
FPS_MAX_FRAMES = 768
# Set the maximum number of video token inputs.
# Here, 128K represents the maximum number of input tokens for the VLLM model.
# Remember to adjust it according to your own configuration.
VIDEO_TOTAL_PIXELS = int(float(os.environ.get('VIDEO_MAX_PIXELS', 128000 * 28 * 28 * 0.9)))
logger.info(f"set VIDEO_TOTAL_PIXELS: {VIDEO_TOTAL_PIXELS}")
def round_by_factor(number: int, factor: int) -> int:
"""Returns the closest integer to 'number' that is divisible by 'factor'."""
return round(number / factor) * factor
def ceil_by_factor(number: int, factor: int) -> int:
"""Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
return math.ceil(number / factor) * factor
def floor_by_factor(number: int, factor: int) -> int:
"""Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
return math.floor(number / factor) * factor
def smart_resize(
height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
) -> tuple[int, int]:
"""
Rescales the image so that the following conditions are met:
1. Both dimensions (height and width) are divisible by 'factor'.
2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
3. The aspect ratio of the image is maintained as closely as possible.
"""
if max(height, width) / min(height, width) > MAX_RATIO:
raise ValueError(
f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
)
h_bar = max(factor, round_by_factor(height, factor))
w_bar = max(factor, round_by_factor(width, factor))
if h_bar * w_bar > max_pixels:
beta = math.sqrt((height * width) / max_pixels)
h_bar = floor_by_factor(height / beta, factor)
w_bar = floor_by_factor(width / beta, factor)
elif h_bar * w_bar < min_pixels:
beta = math.sqrt(min_pixels / (height * width))
h_bar = ceil_by_factor(height * beta, factor)
w_bar = ceil_by_factor(width * beta, factor)
return h_bar, w_bar
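# Illustrative example (not in the original source): with the defaults above,
# smart_resize(1000, 2000) rounds both sides to multiples of 28 and returns (1008, 1988);
# the product (~2.0M pixels) already lies within [MIN_PIXELS, MAX_PIXELS], so no extra
# up- or down-scaling is applied.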
def to_rgb(pil_image: Image.Image) -> Image.Image:
if pil_image.mode == 'RGBA':
white_background = Image.new("RGB", pil_image.size, (255, 255, 255))
white_background.paste(pil_image, mask=pil_image.split()[3]) # Use alpha channel as mask
return white_background
else:
return pil_image.convert("RGB")
def fetch_image(ele: dict[str, str | Image.Image], size_factor: int = IMAGE_FACTOR) -> Image.Image:
if "image" in ele:
image = ele["image"]
else:
image = ele["image_url"]
image_obj = None
if isinstance(image, Image.Image):
image_obj = image
elif image.startswith("http://") or image.startswith("https://"):
response = requests.get(image, stream=True)
image_obj = Image.open(BytesIO(response.content))
elif image.startswith("file://"):
image_obj = Image.open(image[7:])
elif image.startswith("data:image"):
if "base64," in image:
_, base64_data = image.split("base64,", 1)
data = base64.b64decode(base64_data)
image_obj = Image.open(BytesIO(data))
else:
image_obj = Image.open(image)
if image_obj is None:
raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}")
image = to_rgb(image_obj)
## resize
if "resized_height" in ele and "resized_width" in ele:
resized_height, resized_width = smart_resize(
ele["resized_height"],
ele["resized_width"],
factor=size_factor,
)
else:
width, height = image.size
min_pixels = ele.get("min_pixels", MIN_PIXELS)
max_pixels = ele.get("max_pixels", MAX_PIXELS)
resized_height, resized_width = smart_resize(
height,
width,
factor=size_factor,
min_pixels=min_pixels,
max_pixels=max_pixels,
)
image = image.resize((resized_width, resized_height))
return image
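# Illustrative example (not in the original source):
# fetch_image({"image": "file:///path/to/img.jpg", "resized_height": 280, "resized_width": 420})
# opens the file, converts it to RGB and resizes it to 420x280 (width x height), since both
# requested sides are already multiples of IMAGE_FACTOR (28).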
def smart_nframes(
ele: dict,
total_frames: int,
video_fps: int | float,
) -> int:
"""calculate the number of frames for video used for model inputs.
Args:
ele (dict): a dict contains the configuration of video.
support either `fps` or `nframes`:
- nframes: the number of frames to extract for model inputs.
- fps: the fps to extract frames for model inputs.
- min_frames: the minimum number of frames of the video, only used when fps is provided.
- max_frames: the maximum number of frames of the video, only used when fps is provided.
total_frames (int): the original total number of frames of the video.
video_fps (int | float): the original fps of the video.
Raises:
ValueError: nframes should be in the interval [FRAME_FACTOR, total_frames].
Returns:
int: the number of frames for video used for model inputs.
"""
assert not ("fps" in ele and "nframes" in ele), "Only accept either `fps` or `nframes`"
if "nframes" in ele:
nframes = round_by_factor(ele["nframes"], FRAME_FACTOR)
else:
fps = ele.get("fps", FPS)
min_frames = ceil_by_factor(ele.get("min_frames", FPS_MIN_FRAMES), FRAME_FACTOR)
max_frames = floor_by_factor(ele.get("max_frames", min(FPS_MAX_FRAMES, total_frames)), FRAME_FACTOR)
nframes = total_frames / video_fps * fps
if nframes > total_frames:
logger.warning(f"smart_nframes: nframes[{nframes}] > total_frames[{total_frames}]")
nframes = min(min(max(nframes, min_frames), max_frames), total_frames)
nframes = floor_by_factor(nframes, FRAME_FACTOR)
if not (FRAME_FACTOR <= nframes and nframes <= total_frames):
raise ValueError(f"nframes should in interval [{FRAME_FACTOR}, {total_frames}], but got {nframes}.")
return nframes
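# Illustrative example (not in the original source): for a 30 fps video with 300 frames and
# the default fps=2.0, nframes = 300 / 30 * 2 = 20, which already lies within
# [FPS_MIN_FRAMES, total_frames] and is divisible by FRAME_FACTOR, so 20 frames are sampled.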
def _read_video_torchvision(
ele: dict,
) -> tuple[torch.Tensor, float]:
"""read video using torchvision.io.read_video
Args:
ele (dict): a dict contains the configuration of video.
support keys:
- video: the path of video. support "file://", "http://", "https://" and local path.
- video_start: the start time of video.
- video_end: the end time of video.
Returns:
tuple[torch.Tensor, float]: the video tensor with shape (T, C, H, W) and the sampled fps.
"""
video_path = ele["video"]
if version.parse(torchvision.__version__) < version.parse("0.19.0"):
if "http://" in video_path or "https://" in video_path:
warnings.warn("torchvision < 0.19.0 does not support http/https video path, please upgrade to 0.19.0.")
if "file://" in video_path:
video_path = video_path[7:]
st = time.time()
video, audio, info = io.read_video(
video_path,
start_pts=ele.get("video_start", 0.0),
end_pts=ele.get("video_end", None),
pts_unit="sec",
output_format="TCHW",
)
total_frames, video_fps = video.size(0), info["video_fps"]
logger.info(f"torchvision: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s")
nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
idx = torch.linspace(0, total_frames - 1, nframes).round().long()
sample_fps = nframes / max(total_frames, 1e-6) * video_fps
video = video[idx]
return video, sample_fps
def is_decord_available() -> bool:
import importlib.util
return importlib.util.find_spec("decord") is not None
def _read_video_decord(
ele: dict,
) -> tuple[torch.Tensor, float]:
"""read video using decord.VideoReader
Args:
ele (dict): a dict contains the configuration of video.
support keys:
- video: the path of video. support "file://", "http://", "https://" and local path.
- video_start: the start time of video.
- video_end: the end time of video.
Returns:
tuple[torch.Tensor, float]: the video tensor with shape (T, C, H, W) and the sampled fps.
"""
import decord
video_path = ele["video"]
st = time.time()
vr = decord.VideoReader(video_path)
# TODO: support start_pts and end_pts
if 'video_start' in ele or 'video_end' in ele:
raise NotImplementedError("not support start_pts and end_pts in decord for now.")
total_frames, video_fps = len(vr), vr.get_avg_fps()
logger.info(f"decord: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s")
nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist()
video = vr.get_batch(idx).asnumpy()
video = torch.tensor(video).permute(0, 3, 1, 2) # Convert to TCHW format
sample_fps = nframes / max(total_frames, 1e-6) * video_fps
return video, sample_fps
VIDEO_READER_BACKENDS = {
"decord": _read_video_decord,
"torchvision": _read_video_torchvision,
}
FORCE_QWENVL_VIDEO_READER = os.getenv("FORCE_QWENVL_VIDEO_READER", None)
@lru_cache(maxsize=1)
def get_video_reader_backend() -> str:
if FORCE_QWENVL_VIDEO_READER is not None:
video_reader_backend = FORCE_QWENVL_VIDEO_READER
elif is_decord_available():
video_reader_backend = "decord"
else:
video_reader_backend = "torchvision"
print(f"qwen-vl-utils using {video_reader_backend} to read video.", file=sys.stderr)
return video_reader_backend
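# Illustrative note: FORCE_QWENVL_VIDEO_READER is read once at import time above, so setting it
# to "decord" or "torchvision" in the environment before importing this module overrides the
# automatic backend selection.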
def fetch_video(ele: dict, image_factor: int = IMAGE_FACTOR, return_video_sample_fps: bool = False) -> torch.Tensor | list[Image.Image]:
if isinstance(ele["video"], str):
video_reader_backend = get_video_reader_backend()
try:
video, sample_fps = VIDEO_READER_BACKENDS[video_reader_backend](ele)
except Exception as e:
logger.warning(f"video_reader_backend {video_reader_backend} error, use torchvision as default, msg: {e}")
video, sample_fps = VIDEO_READER_BACKENDS["torchvision"](ele)
nframes, _, height, width = video.shape
min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS)
total_pixels = ele.get("total_pixels", VIDEO_TOTAL_PIXELS)
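# Per-frame pixel budget computed below: min(VIDEO_MAX_PIXELS, total_pixels * FRAME_FACTOR / nframes),
# floored at int(min_pixels * 1.05) so each frame keeps a usable minimum resolution.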
max_pixels = max(min(VIDEO_MAX_PIXELS, total_pixels / nframes * FRAME_FACTOR), int(min_pixels * 1.05))
max_pixels_supposed = ele.get("max_pixels", max_pixels)
if max_pixels_supposed > max_pixels:
logger.warning(f"The given max_pixels[{max_pixels_supposed}] exceeds limit[{max_pixels}].")
max_pixels = min(max_pixels_supposed, max_pixels)
if "resized_height" in ele and "resized_width" in ele:
resized_height, resized_width = smart_resize(
ele["resized_height"],
ele["resized_width"],
factor=image_factor,
)
else:
resized_height, resized_width = smart_resize(
height,
width,
factor=image_factor,
min_pixels=min_pixels,
max_pixels=max_pixels,
)
video = transforms.functional.resize(
video,
[resized_height, resized_width],
interpolation=InterpolationMode.BICUBIC,
antialias=True,
).float()
if return_video_sample_fps:
return video, sample_fps
return video
else:
assert isinstance(ele["video"], (list, tuple))
process_info = ele.copy()
process_info.pop("type", None)
process_info.pop("video", None)
images = [
fetch_image({"image": video_element, **process_info}, size_factor=image_factor)
for video_element in ele["video"]
]
nframes = ceil_by_factor(len(images), FRAME_FACTOR)
if len(images) < nframes:
images.extend([images[-1]] * (nframes - len(images)))
if return_video_sample_fps:
return images, process_info.pop("fps", 2.0)
return images
def extract_vision_info(conversations: list[dict] | list[list[dict]]) -> list[dict]:
vision_infos = []
if isinstance(conversations[0], dict):
conversations = [conversations]
for conversation in conversations:
for message in conversation:
if isinstance(message["content"], list):
for ele in message["content"]:
if (
"image" in ele
or "image_url" in ele
or "video" in ele
or ele["type"] in ("image", "image_url", "video")
):
vision_infos.append(ele)
return vision_infos
def process_vision_info(
conversations: list[dict] | list[list[dict]],
return_video_kwargs: bool = False,
) -> tuple[list[Image.Image] | None, list[torch.Tensor | list[Image.Image]] | None, Optional[dict]]:
vision_infos = extract_vision_info(conversations)
## Read images or videos
image_inputs = []
video_inputs = []
video_sample_fps_list = []
for vision_info in vision_infos:
if "image" in vision_info or "image_url" in vision_info:
image_inputs.append(fetch_image(vision_info))
elif "video" in vision_info:
video_input, video_sample_fps = fetch_video(vision_info, return_video_sample_fps=True)
video_sample_fps_list.append(video_sample_fps)
video_inputs.append(video_input)
else:
raise ValueError("image, image_url or video should in content.")
if len(image_inputs) == 0:
image_inputs = None
if len(video_inputs) == 0:
video_inputs = None
if return_video_kwargs:
return image_inputs, video_inputs, {'fps': video_sample_fps_list}
return image_inputs, video_inputs
# Core dependencies
gradio==5.4.0
gradio_client==1.4.2
qwen-vl-utils==0.0.10
transformers-stream-generator==0.0.4
torch
torchvision
git+https://github.com/huggingface/transformers.git
accelerate
av
# Optional dependency
# Uncomment the following line if you need flash-attn
# flash-attn==2.6.1
# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import copy
import re
from argparse import ArgumentParser
from threading import Thread
import gradio as gr
import torch
from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration, TextIteratorStreamer
DEFAULT_CKPT_PATH = 'Qwen/Qwen2.5-VL-7B-Instruct'
def _get_args():
parser = ArgumentParser()
parser.add_argument('-c',
'--checkpoint-path',
type=str,
default=DEFAULT_CKPT_PATH,
help='Checkpoint name or path, defaults to %(default)r')
parser.add_argument('--cpu-only', action='store_true', help='Run demo with CPU only')
parser.add_argument('--flash-attn2',
action='store_true',
default=False,
help='Enable flash_attention_2 when loading the model.')
parser.add_argument('--share',
action='store_true',
default=False,
help='Create a publicly shareable link for the interface.')
parser.add_argument('--inbrowser',
action='store_true',
default=False,
help='Automatically launch the interface in a new tab on the default browser.')
parser.add_argument('--server-port', type=int, default=7860, help='Demo server port.')
parser.add_argument('--server-name', type=str, default='127.0.0.1', help='Demo server name.')
args = parser.parse_args()
return args
def _load_model_processor(args):
if args.cpu_only:
device_map = 'cpu'
else:
device_map = 'auto'
# Check if flash-attn2 flag is enabled and load model accordingly
if args.flash_attn2:
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(args.checkpoint_path,
torch_dtype='auto',
attn_implementation='flash_attention_2',
device_map=device_map)
else:
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(args.checkpoint_path, device_map=device_map)
processor = AutoProcessor.from_pretrained(args.checkpoint_path)
return model, processor
def _parse_text(text):
lines = text.split('\n')
lines = [line for line in lines if line != '']
count = 0
for i, line in enumerate(lines):
if '```' in line:
count += 1
items = line.split('`')
if count % 2 == 1:
lines[i] = f'<pre><code class="language-{items[-1]}">'
else:
lines[i] = '<br></code></pre>'
else:
if i > 0:
if count % 2 == 1:
line = line.replace('`', r'\`')
line = line.replace('<', '&lt;')
line = line.replace('>', '&gt;')
line = line.replace(' ', '&nbsp;')
line = line.replace('*', '&ast;')
line = line.replace('_', '&lowbar;')
line = line.replace('-', '&#45;')
line = line.replace('.', '&#46;')
line = line.replace('!', '&#33;')
line = line.replace('(', '&#40;')
line = line.replace(')', '&#41;')
line = line.replace('$', '&#36;')
lines[i] = '<br>' + line
text = ''.join(lines)
return text
def _remove_image_special(text):
text = text.replace('<ref>', '').replace('</ref>', '')
return re.sub(r'<box>.*?(</box>|$)', '', text)
def _is_video_file(filename):
video_extensions = ['.mp4', '.avi', '.mkv', '.mov', '.wmv', '.flv', '.webm', '.mpeg']
return any(filename.lower().endswith(ext) for ext in video_extensions)
def _gc():
import gc
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
def _transform_messages(original_messages):
transformed_messages = []
for message in original_messages:
new_content = []
for item in message['content']:
if 'image' in item:
new_item = {'type': 'image', 'image': item['image']}
elif 'text' in item:
new_item = {'type': 'text', 'text': item['text']}
elif 'video' in item:
new_item = {'type': 'video', 'video': item['video']}
else:
continue
new_content.append(new_item)
new_message = {'role': message['role'], 'content': new_content}
transformed_messages.append(new_message)
return transformed_messages
def _launch_demo(args, model, processor):
def call_local_model(model, processor, messages):
messages = _transform_messages(messages)
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors='pt')
inputs = inputs.to(model.device)
tokenizer = processor.tokenizer
streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
gen_kwargs = {'max_new_tokens': 512, 'streamer': streamer, **inputs}
thread = Thread(target=model.generate, kwargs=gen_kwargs)
thread.start()
generated_text = ''
for new_text in streamer:
generated_text += new_text
yield generated_text
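# Note: call_local_model yields the cumulative text generated so far on each iteration, so the
# UI handlers below simply overwrite the last chatbot message with the newest value.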
def create_predict_fn():
def predict(_chatbot, task_history):
nonlocal model, processor
chat_query = _chatbot[-1][0]
query = task_history[-1][0]
if len(chat_query) == 0:
_chatbot.pop()
task_history.pop()
return _chatbot
print('User: ' + _parse_text(query))
history_cp = copy.deepcopy(task_history)
full_response = ''
messages = []
content = []
for q, a in history_cp:
if isinstance(q, (tuple, list)):
if _is_video_file(q[0]):
content.append({'video': f'file://{q[0]}'})
else:
content.append({'image': f'file://{q[0]}'})
else:
content.append({'text': q})
messages.append({'role': 'user', 'content': content})
messages.append({'role': 'assistant', 'content': [{'text': a}]})
content = []
messages.pop()
for response in call_local_model(model, processor, messages):
_chatbot[-1] = (_parse_text(chat_query), _remove_image_special(_parse_text(response)))
yield _chatbot
full_response = _parse_text(response)
task_history[-1] = (query, full_response)
print('Qwen-VL-Chat: ' + _parse_text(full_response))
yield _chatbot
return predict
def create_regenerate_fn():
def regenerate(_chatbot, task_history):
nonlocal model, processor
if not task_history:
return _chatbot
item = task_history[-1]
if item[1] is None:
return _chatbot
task_history[-1] = (item[0], None)
chatbot_item = _chatbot.pop(-1)
if chatbot_item[0] is None:
_chatbot[-1] = (_chatbot[-1][0], None)
else:
_chatbot.append((chatbot_item[0], None))
_chatbot_gen = predict(_chatbot, task_history)
for _chatbot in _chatbot_gen:
yield _chatbot
return regenerate
predict = create_predict_fn()
regenerate = create_regenerate_fn()
def add_text(history, task_history, text):
task_text = text
history = history if history is not None else []
task_history = task_history if task_history is not None else []
history = history + [(_parse_text(text), None)]
task_history = task_history + [(task_text, None)]
return history, task_history, ''
def add_file(history, task_history, file):
history = history if history is not None else []
task_history = task_history if task_history is not None else []
history = history + [((file.name,), None)]
task_history = task_history + [((file.name,), None)]
return history, task_history
def reset_user_input():
return gr.update(value='')
def reset_state(_chatbot, task_history):
task_history.clear()
_chatbot.clear()
_gc()
return []
with gr.Blocks() as demo:
gr.Markdown("""\
<p align="center"><img src="https://modelscope.oss-cn-beijing.aliyuncs.com/resource/qwen.png" style="height: 80px"/><p>"""
)
gr.Markdown("""<center><font size=8>Qwen2.5-VL</center>""")
gr.Markdown("""\
<center><font size=3>This WebUI is based on Qwen2.5-VL, developed by Alibaba Cloud.</center>""")
gr.Markdown("""<center><font size=3>本WebUI基于Qwen2.5-VL。</center>""")
chatbot = gr.Chatbot(label='Qwen2.5-VL', elem_classes='control-height', height=500)
query = gr.Textbox(lines=2, label='Input')
task_history = gr.State([])
with gr.Row():
addfile_btn = gr.UploadButton('📁 Upload (上传文件)', file_types=['image', 'video'])
submit_btn = gr.Button('🚀 Submit (发送)')
regen_btn = gr.Button('🤔️ Regenerate (重试)')
empty_bin = gr.Button('🧹 Clear History (清除历史)')
submit_btn.click(add_text, [chatbot, task_history, query],
[chatbot, task_history]).then(predict, [chatbot, task_history], [chatbot], show_progress=True)
submit_btn.click(reset_user_input, [], [query])
empty_bin.click(reset_state, [chatbot, task_history], [chatbot], show_progress=True)
regen_btn.click(regenerate, [chatbot, task_history], [chatbot], show_progress=True)
addfile_btn.upload(add_file, [chatbot, task_history, addfile_btn], [chatbot, task_history], show_progress=True)
gr.Markdown("""\
<font size=2>Note: This demo is governed by the original license of Qwen2.5-VL. \
We strongly advise users not to knowingly generate or allow others to knowingly generate harmful content, \
including hate speech, violence, pornography, deception, etc. \
(注:本演示受Qwen2.5-VL的许可协议限制。我们强烈建议,用户不应传播及不应允许他人传播以下内容,\
包括但不限于仇恨言论、暴力、色情、欺诈相关的有害信息。)""")
demo.queue().launch(
share=args.share,
inbrowser=args.inbrowser,
server_port=args.server_port,
server_name=args.server_name,
)
def main():
args = _get_args()
model, processor = _load_model_processor(args)
_launch_demo(args, model, processor)
if __name__ == '__main__':
main()
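# Example launch (assuming this script is saved as, e.g., web_demo_mm.py; the flags correspond
# to _get_args above):
#   python web_demo_mm.py -c Qwen/Qwen2.5-VL-7B-Instruct --flash-attn2 --server-port 7860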
import base64
import copy
import io
import os
import pathlib
import shutil
import tempfile
import time
import uuid
from argparse import ArgumentParser
from threading import Thread
from typing import Dict

import gradio as gr
import imagesize
import openai
from PIL import Image, ImageFile
from qwen_vl_utils import process_vision_info, smart_resize
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration, TextIteratorStreamer
ImageFile.LOAD_TRUNCATED_IMAGES = True
ImageFile.MAX_IMAGE_PIXELS = None
Image.MAX_IMAGE_PIXELS = None
image_transform = None
oss_reader = None
MAX_SEQ_LEN = 32000
DEFAULT_CKPT_PATH = 'Qwen/Qwen2.5-VL-7B-Instruct'
def compute_seqlen_estimated(tokenizer, json_input, sample_strategy_func):
total_seq_len, img_seq_len, text_seq_len = 0, 0, 0
for chat_block in json_input:
chat_block['seq_len'] = 4
role_length = len(tokenizer.tokenize(chat_block['role']))
chat_block['seq_len'] += role_length
text_seq_len += role_length
for element in chat_block['content']:
if 'image' in element:
if 'width' not in element:
element['width'], element['height'] = imagesize.get(
element['image'].split('file://')[1])
height, width = element['height'], element['width']
height, width = sample_strategy_func(height, width)
resized_height, resized_width = smart_resize(
height, width, max_pixels=14*14*4*5120) # , min_pixels=14*14*4*512
seq_len = resized_height * resized_width // 28 // 28 + 2 # add img_bos & img_eos
element.update({
'resized_height': resized_height,
'resized_width': resized_width,
'seq_len': seq_len,
})
img_seq_len += element['seq_len']
chat_block['seq_len'] += element['seq_len']
elif 'video' in element:
if isinstance(element['video'], (list, tuple)):
if 'width' not in element:
element['width'], element['height'] = imagesize.get(
element['video'][0].split('file://')[1])
height, width = element['height'], element['width']
height, width = sample_strategy_func(height, width)
resized_height, resized_width = smart_resize(
height, width, max_pixels=14*14*4*5120) # , min_pixels=14*14*4*512
seq_len = (resized_height * resized_width // 28 // 28) * \
(len(element['video'])//2)+2 # add img_bos & img_eos
element.update({
'resized_height': resized_height,
'resized_width': resized_width,
'seq_len': seq_len,
})
img_seq_len += element['seq_len']
chat_block['seq_len'] += element['seq_len']
else:
raise NotImplementedError
elif 'text' in element:
if 'seq_len' in element:
text_seq_len += element['seq_len']
else:
element['seq_len'] = len(
tokenizer.tokenize(element['text']))
text_seq_len += element['seq_len']
chat_block['seq_len'] += element['seq_len']
elif 'prompt' in element:
if 'seq_len' in element:
text_seq_len += element['seq_len']
else:
element['seq_len'] = len(
tokenizer.tokenize(element['prompt']))
text_seq_len += element['seq_len']
chat_block['seq_len'] += element['seq_len']
else:
raise ValueError('Unknown element: ' + str(element))
total_seq_len += chat_block['seq_len']
assert img_seq_len + text_seq_len + 4 * len(json_input) == total_seq_len
total_seq_len += 1
return {
'content': json_input,
'img_seq_len': img_seq_len,
'text_seq_len': text_seq_len,
'seq_len': total_seq_len,
}
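# Illustrative example (not in the original source): a single user turn containing one 980x980
# image contributes 4 chat-template tokens + the tokenized role length + (980*980)//28//28 + 2
# = 1227 image tokens (including img_bos/img_eos); the function then adds 1 more token to the
# returned total seq_len.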
def _get_args():
parser = ArgumentParser()
parser.add_argument('-c',
'--checkpoint-path',
type=str,
default=DEFAULT_CKPT_PATH,
help='Checkpoint name or path, defaults to %(default)r')
parser.add_argument('--cpu-only', action='store_true',
help='Run demo with CPU only')
parser.add_argument('--flash-attn2',
action='store_true',
default=False,
help='Enable flash_attention_2 when loading the model.')
parser.add_argument('--share',
action='store_true',
default=False,
help='Create a publicly shareable link for the interface.')
parser.add_argument('--inbrowser',
action='store_true',
default=False,
help='Automatically launch the interface in a new tab on the default browser.')
parser.add_argument('--server-port', type=int,
default=7860, help='Demo server port.')
parser.add_argument('--server-name', type=str,
default='127.0.0.1', help='Demo server name.')
args = parser.parse_args()
return args
def _load_model_processor(args):
if args.cpu_only:
device_map = 'cpu'
else:
device_map = 'auto'
# Check if flash-attn2 flag is enabled and load model accordingly
if args.flash_attn2:
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(args.checkpoint_path,
torch_dtype='auto',
attn_implementation='flash_attention_2',
device_map=device_map)
else:
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
args.checkpoint_path, device_map=device_map)
processor = AutoProcessor.from_pretrained(args.checkpoint_path)
return model, processor
class ChatSessionState:
def __init__(self, session_id: str):
self.session_id: str = session_id
self.system_prompt: str = 'You are a helpful assistant.'
self.model_name = ''
self.image_cache = []
def _transform_messages(original_messages):
transformed_messages = []
for message in original_messages:
new_content = []
for item in message['content']:
if 'image' in item:
new_item = {'type': 'image', 'image': item['image']}
elif 'text' in item:
new_item = {'type': 'text', 'text': item['text']}
elif 'video' in item:
new_item = {'type': 'video', 'video': item['video']}
else:
continue
new_content.append(new_item)
new_message = {'role': message['role'], 'content': new_content}
transformed_messages.append(new_message)
return transformed_messages
class Worker:
def __init__(self):
self.uids = []
capture_image_dir = os.path.join("/tmp/captured_images")
os.makedirs(capture_image_dir, exist_ok=True)
self.capture_image_dir = capture_image_dir # uid-to-messages
self.save_dir = dict()
self.messages = dict() # uid-to-messages
self.resized_width, self.resized_height = 640, 420
# self.message_truncate = 0
self.message_truncate = {}
self.chat_session_states: Dict[str, ChatSessionState] = {}
self.image_cache = {}
def convert_image_to_base64(self, file_name):
if file_name not in self.image_cache:
self.image_cache[file_name] = {}
if 'data_url' not in self.image_cache[file_name]:
with open(file_name, 'rb') as f:
self.image_cache[file_name]['data_url'] = 'data:image/png;base64,' + \
base64.b64encode(f.read()).decode('utf-8')
assert self.image_cache[file_name]['data_url']
return self.image_cache[file_name]['data_url']
def get_session_state(self, session_id: str) -> ChatSessionState:
"""
Retrieves the chat session state object for a given session ID.
If the session ID does not exist in the currently managed session states,
a new session state object is created and added to the list of managed sessions.
Parameters:
session_id: The unique identifier for the session.
Returns:
The session state object corresponding to the session ID.
"""
# Check if the current session state collection already contains this session ID
if session_id not in self.chat_session_states:
# If it does not exist, create a new session state object and add it to the collection
self.chat_session_states[session_id] = ChatSessionState(session_id)
# Return the corresponding session state object
return self.chat_session_states[session_id]
def get_message_truncate(self, session_id):
if session_id not in self.message_truncate:
self.message_truncate[session_id] = 0
return self.message_truncate[session_id]
def truncate_messages_adaptive(self, messages):
while True:
seq_len = compute_seqlen_estimated(tokenizer, copy.deepcopy(
messages), sample_strategy_func=lambda h, w: (h, w))['seq_len']
if seq_len < MAX_SEQ_LEN:
break
# Remove the first block in content history:
if len(messages[0]['content']) > 0 and 'video' in messages[0]['content'][0]:
messages[0]['content'][0]['video'] = messages[0]['content'][0]['video'][2:]
if len(messages[0]['content'][0]['video']) == 0:
messages[0]['content'] = messages[0]['content'][1:]
else:
messages[0]['content'] = messages[0]['content'][1:]
# If the first block is empty, remove it:
if len(messages[0]['content']) == 0:
messages.pop(0)
# If role is assistant, remove the first block in content history:
if messages[0]['role'] == 'assistant':
messages.pop(0)
return messages
def truncate_messages_by_count(self, messages, cnt):
for i in range(cnt):
# Remove the first block in content history:
if len(messages[0]['content']) > 0 and 'video' in messages[0]['content'][0]:
messages[0]['content'][0]['video'] = messages[0]['content'][0]['video'][2:]
if len(messages[0]['content'][0]['video']) == 0:
messages[0]['content'] = messages[0]['content'][1:]
else:
messages[0]['content'] = messages[0]['content'][1:]
# If the first block is empty, remove it:
if len(messages[0]['content']) == 0:
messages.pop(0)
# If role is assistant, remove the first block in content history:
if messages[0]['role'] == 'assistant':
messages.pop(0)
def get_save_dir(self, session_id):
if self.save_dir.get(session_id) is None:
temp_dir = tempfile.mkdtemp(dir=self.capture_image_dir)
self.save_dir[session_id] = temp_dir
return self.save_dir[session_id]
def get_messages(self, session_id):
if self.messages.get(session_id) is None:
self.messages[session_id] = []
return self.messages[session_id]
def update_messages(self, session_id, role, content):
if self.messages.get(session_id) is None:
self.messages[session_id] = []
messages = self.messages[session_id]
if len(messages) == 0 or messages[-1]["role"] != role:
messages.append({
"role": role,
"content": [content]
})
elif "video" in content and isinstance(content["video"], (list, tuple)) and "video" in messages[-1]["content"][-1] and isinstance(messages[-1]["content"][-1]["video"], (list, tuple)):
messages[-1]["content"][-1]['video'].extend(content["video"])
else:
# If content and last message are all with type text, merge them
if 'text' in messages[-1]["content"][-1] and 'text' in content:
messages[-1]["content"][-1]['text'] += content["text"]
else:
messages[-1]["content"].append(content)
self.messages[session_id] = messages
def get_timestamp(self):
return time.time()
def chat(self, messages, request: gr.Request):
messages = _transform_messages(messages)
text = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(text=[text], images=image_inputs,
videos=video_inputs, padding=True, return_tensors='pt')
inputs = inputs.to(model.device)
streamer = TextIteratorStreamer(
tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
gen_kwargs = {'max_new_tokens': 512, 'streamer': streamer, **inputs}
thread = Thread(target=model.generate, kwargs=gen_kwargs)
thread.start()
for new_text in streamer:
yield new_text
def add_text(self, history, text, request: gr.Request):
session_id = request.session_hash
session_state: ChatSessionState = self.get_session_state(
request.session_hash)
if len(session_state.image_cache) > 0:
for i, (timestamp, image_path) in enumerate(session_state.image_cache):
if i % 2 == 0:
content = {"video": [f"file://{image_path}"]}
else:
content["video"].append(f"file://{image_path}")
self.update_messages(
session_id, role="user", content=content)
if i == len(session_state.image_cache)-1 and i % 2 == 0:
content["video"].append(content["video"][-1])
self.update_messages(
session_id, role="user", content=content)
session_state.image_cache.clear()
self.update_messages(session_id, role="user", content={
"type": "text", "text": text})
history = history + [(text, None)]
return history, ""
def add_file(self, history, file, request: gr.Request):
session_id = request.session_hash
session_state: ChatSessionState = self.get_session_state(session_id)
if isinstance(file, str) and file.startswith('data:'):
# get binary bytes
data = base64.b64decode(file.split('base64,')[1])
# Create a file name using uuid
filename = f'{uuid.uuid4()}.jpg'
save_dir = self.get_save_dir(session_id)
savename = os.path.join(save_dir, filename)
# Save the file
with open(savename, 'wb') as f:
f.write(data)
self.update_messages(session_id, role="user", content={
"image": f"file://{savename}"})
else:
filename = os.path.basename(file.name)
save_dir = self.get_save_dir(session_id)
savename = os.path.join(save_dir, filename)
if file.name.endswith('.mp4') or file.name.endswith('.mov'):
shutil.copy(file.name, savename)
os.makedirs(file.name + '.frames', exist_ok=True)
os.system(
f'ffmpeg -i {file.name} -vf "scale=320:-1" -r 0.25 {file.name}.frames/%d.jpg')
file_index = 1
frame_list = []
while True:
if os.path.isfile(os.path.join(f'{file.name}.frames/{file_index}.jpg')):
frame_list.append(os.path.join(
f'file://{file.name}.frames/{file_index}.jpg'))
file_index += 1
else:
break
if len(frame_list) % 2 != 0:
frame_list = frame_list[1:]
self.update_messages(session_id, role="user", content={
"video": frame_list})
else:
shutil.copy(file.name, savename)
self.update_messages(session_id, role="user", content={
"image": f"file://{savename}"})
history = history + [((savename,), None)]
return history
def add_image_to_streaming_cache(self, file, width, height, request: gr.Request):
session_id = request.session_hash
session_state: ChatSessionState = self.get_session_state(session_id)
timestamp = self.get_timestamp()
# If file is an image URL starting with data:, save it to the session directory
if isinstance(file, str) and file.startswith('data:'):
# get binary bytes
data = base64.b64decode(file.split('base64,')[1])
width, height = int(width), int(height)
# Load the image using PIL
image = Image.open(io.BytesIO(data))
# If width == -1, no need to scale the image
if width == -1:
pass
else:
# If height == -1, keep aspect ratio
if height == -1:
height = round(width * image.height / float(image.width))
image = image.resize((width, height), Image.LANCZOS)
# Create a file name using uuid
filename = f'{uuid.uuid4()}.jpg'
save_dir = self.get_save_dir(session_id)
savename = os.path.join(save_dir, filename)
# Save the file
image.save(savename, "JPEG")
else:
filename = os.path.basename(file.name)
save_dir = self.get_save_dir(session_id)
savename = os.path.join(save_dir, filename)
shutil.copy(file.name, savename)
session_state.image_cache.append((timestamp, savename))
def response(self, chatbot_messages, request: gr.Request):
session_id = request.session_hash
messages = self.get_messages(session_id)
self.truncate_messages_adaptive(messages)
messages = copy.deepcopy(messages)
chatbot_messages = copy.deepcopy(chatbot_messages)
if chatbot_messages is None:
chatbot_messages = []
truncate_count = 0
while True:
compiled_messages = copy.deepcopy(messages)
self.truncate_messages_by_count(
compiled_messages, cnt=truncate_count)
# Convert file:// image urls to data:base64 urls
for message in compiled_messages:
for content in message['content']:
if 'image' in content:
if content['image'].startswith('file://'):
content['image'] = self.convert_image_to_base64(
content['image'][7:])
elif 'video' in content and isinstance(content['video'], (list, tuple)):
for frame_i in range(len(content['video'])):
if content['video'][frame_i].startswith('file://'):
content['video'][frame_i] = self.convert_image_to_base64(
content['video'][frame_i][7:])
rep = self.chat(compiled_messages, request=request)
try:
for content in rep:
if not content:
continue
self.update_messages(session_id, role="assistant", content={
"type": "text", "text": content})
if not chatbot_messages[-1][-1]:
chatbot_messages[-1][-1] = content
else:
chatbot_messages[-1][-1] += content
yield chatbot_messages
break
except openai.BadRequestError as e:
print(e)
if 'maximum context length' not in str(e):
raise e
if self.messages[session_id][-1]['role'] == 'assistant':
chatbot_messages[-1][-1] = ''
self.messages[session_id] = self.messages[session_id][:-1]
# self.messages[session_id][-1]['content'][-1] = {'text': ''}
self.message_truncate[session_id] += 1
recorder_js = pathlib.Path('recorder.js').read_text()
main_js = pathlib.Path('main.js').read_text()
GLOBAL_JS = pathlib.Path('global.js').read_text().replace('let recorder_js = null;', recorder_js).replace(
'let main_js = null;', main_js)
def main():
with gr.Blocks(js=GLOBAL_JS) as demo:
gr.Markdown("""\
<p align="center"><img src="https://modelscope.oss-cn-beijing.aliyuncs.com/resource/qwen.png" style="height: 80px"/><p>"""
)
gr.Markdown("""<center><font size=8>Qwen2-VL</center>""")
gr.Markdown("""\
<center><font size=3>This WebUI is based on Qwen2-VL, developed by Alibaba Cloud.</center>""")
gr.Markdown("""<center><font size=3>本WebUI基于Qwen2-VL。</center>""")
with gr.Accordion("Advanced Settings", open=False):
with gr.Accordion("System Prompt", open=False):
textbox_system_prompt = gr.Textbox(
value="You are a helpful assistant.", label="System Prompt")
with gr.Row():
with gr.Column(scale=1):
with gr.Tab("Camera"):
image_camera = gr.Image(sources='webcam', label="Camera Preview",
mirror_webcam=False, elem_id="gradio_image_camera_preview")
with gr.Accordion("Camera Settings", open=False):
with gr.Row():
camera_frame_interval = gr.Textbox(
"1", label="Frame interval or (1 / FPS)", elem_id="gradio_camera_frame_interval", interactive=True)
with gr.Row():
camera_width = gr.Textbox(
"640", label="Width (-1 = original resolution)")
camera_height = gr.Textbox(
"-1", label="Height (-1 = keep aspect ratio)")
with gr.Row():
button_camera_stream = gr.Button(
"Stream", elem_id="gradio_button_camera_stream")
button_camera_snapshot = gr.Button(
"Snapshot", elem_id="gradio_button_camera_snapshot")
button_camera_stream_submit = gr.Button(
"Snapshot", elem_id="gradio_button_camera_stream_submit", visible=False)
with gr.Tab("Screen"):
image_screen = gr.Image(
sources='webcam', label="Screen Preview", elem_id="gradio_image_screen_preview")
with gr.Accordion("Screen Settings", open=False):
with gr.Row():
screen_frame_interval = gr.Textbox(
"5", label="Frame interval or (1 / FPS)", elem_id="gradio_screen_frame_interval", interactive=True)
with gr.Row():
screen_width = gr.Textbox(
"-1", label="Width (-1 = original resolution)")
screen_height = gr.Textbox(
"-1", label="Height (-1 = keep aspect ratio)")
with gr.Row():
button_screen_stream = gr.Button(
"Stream", elem_id="gradio_button_screen_stream")
button_screen_snapshot = gr.Button(
"Snapshot", elem_id="gradio_button_screen_snapshot")
button_screen_stream_submit = gr.Button(
"Snapshot", elem_id="gradio_button_screen_stream_submit", visible=False)
with gr.Column(scale=2):
chatbot = gr.Chatbot([], elem_id="chatofa", height=500)
with gr.Row():
txt = gr.Textbox(
show_label=False,
placeholder="Enter text and press enter, or upload an image",
container=False,
scale=5,
)
btn = gr.UploadButton(
"📁", file_types=["image", "video", "audio"], scale=1)
txt.submit(
fn=worker.add_text,
inputs=[chatbot, txt],
outputs=[chatbot, txt]
).then(
fn=worker.response,
inputs=[chatbot],
outputs=chatbot
)
btn.upload(
worker.add_file,
inputs=[chatbot, btn],
outputs=[chatbot]
)
# Camera
button_camera_snapshot.click(
worker.add_file,
inputs=[chatbot, button_camera_snapshot],
outputs=[chatbot],
js="(p1, p2) => [p1, window.getCameraFrame()]",
)
button_camera_stream_submit.click(
worker.add_image_to_streaming_cache,
inputs=[button_camera_stream_submit,
camera_width, camera_height],
outputs=[],
js="(p1, p2, p3) => [window.getCameraFrame(), p2, p3]",
)
button_camera_stream.click(
lambda x: None,
inputs=[button_camera_stream],
outputs=[],
js="(p1, p2) => (window.startCameraStreaming())"
)
# Screen
button_screen_snapshot.click(
worker.add_file,
inputs=[chatbot, button_screen_snapshot],
outputs=[chatbot],
js="(p1, p2) => [p1, window.getScreenshotFrame()]",
)
button_screen_stream_submit.click(
worker.add_image_to_streaming_cache,
inputs=[button_screen_stream_submit,
screen_width, screen_height],
outputs=[],
js="(p1, p2, p3) => [window.getScreenshotFrame(), p2, p3]",
)
button_screen_stream.click(
lambda x: None,
inputs=[button_screen_stream],
outputs=[],
js="(p1, p2) => (window.startScreenStreaming())"
)
with gr.Row():
gr.Markdown("""\
<font size=2>Note: This demo is governed by the original license of Qwen2-VL. \
We strongly advise users not to knowingly generate or allow others to knowingly generate harmful content, \
including hate speech, violence, pornography, deception, etc. \
(注:本演示受Qwen2-VL的许可协议限制。我们强烈建议,用户不应传播及不应允许他人传播以下内容,\
包括但不限于仇恨言论、暴力、色情、欺诈相关的有害信息。)""")
demo.launch(
share=args.share,
inbrowser=args.inbrowser,
server_port=args.server_port,
server_name=args.server_name,
)
if __name__ == '__main__':
worker = Worker()
args = _get_args()
model, processor = _load_model_processor(args)
tokenizer = processor.tokenizer
main()
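# Example launch (assuming this script is saved under a name of your choice, e.g.
# web_demo_streaming.py, and that recorder.js, main.js and global.js sit in the working
# directory, since GLOBAL_JS is built from them above):
#   python web_demo_streaming.py -c Qwen/Qwen2.5-VL-7B-Instruct --server-port 7860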
// Setup if needed and start recording.
async () => {
screenButton = document.getElementById('gradio_image_screen_preview'); //.querySelector('button');
//eventListers = window.getEventListeners(screenButton);
//screenButton.removeEventListener('click', eventListers['click'][0].listener, eventListers['click'][0].useCapture)
// If window.getScreenshotFrame does not exist:
if (!window.getScreenshotFrame) {
// Define the function
window.getScreenshotFrame = () => {
// Get the video element
var video = document.getElementById('gradio_image_screen_preview').querySelector('video');
// Get the canvas element
var canvas = document.getElementsByTagName('canvas')[0];
// Get the canvas context
var ctx = canvas.getContext('2d');
// Set the canvas size to match the video size
canvas.width = video.videoWidth;
canvas.height = video.videoHeight;
// Draw the current frame on the canvas
ctx.drawImage(video, 0, 0, canvas.width, canvas.height);
return canvas.toDataURL('image/jpeg', 1.0);
}
}
if (!window.getCameraFrame) {
// Define the function
window.getCameraFrame = () => {
// Get the video element
var video = document.getElementById('gradio_image_camera_preview').querySelector('video');
// Get the canvas element
var canvas = document.getElementsByTagName('canvas')[0];
// Get the canvas context
var ctx = canvas.getContext('2d');
// Set the canvas size to match the video size
canvas.width = video.videoWidth;
canvas.height = video.videoHeight;
// Draw the current frame on the canvas
ctx.drawImage(video, 0, 0, canvas.width, canvas.height);
return canvas.toDataURL('image/jpeg', 1.0);
}
}
if(!window.startCameraStreaming) {
window.startCameraStreaming = () => {
var intervalString = document.getElementById("gradio_camera_frame_interval").getElementsByTagName('textarea')[0].value;
var interval = parseFloat(intervalString) * 1000;
console.log("Start camera, interval: " + interval + "")
window.cameraIntervalId = setInterval(() => document.getElementById('gradio_button_camera_stream_submit').click(), interval)
}
}
if(!window.startScreenStreaming) {
window.startScreenStreaming = () => {
var intervalString = document.getElementById("gradio_screen_frame_interval").getElementsByTagName('textarea')[0].value;
var interval = parseFloat(intervalString) * 1000;
console.log("Start screen, interval: " + interval + "")
window.screenIntervalId = setInterval(() => document.getElementById('gradio_button_screen_stream_submit').click(), interval)
}
}
if(document.getElementsByTagName('canvas').length <= 0) {
var canvasElement = document.createElement('canvas');
canvasElement.style.position = 'fixed';
canvasElement.style.bottom = '0px';
canvasElement.style.right = '0px';
canvasElement.style.width = '32px'; // You can adjust the width as needed
canvasElement.style.height = '32px'; // Height adapts to keep the aspect ratio
document.body.appendChild(canvasElement);
}
screenButton.addEventListener('click', function (e) {
//alert('Hello world! 666')
// Set up recording functions if not already initialized
if (!window.startRecording) {
let recorder_js = null;
let main_js = null;
}
if (!window.getVideoSnpapshot) {
// Synchronous function to get a video snapshot
window.getVideoSnpapshot = () => {
var canvas = document.getElementsByTagName('canvas')[0];
var ctx = canvas.getContext('2d');
// window.getComputedStyle(canvas)
if(canvas.width != canvas.clientWidth) {
canvas.width = canvas.clientWidth
}
if(canvas.height != canvas.clientHeight) {
canvas.height = canvas.clientHeight
}
if(!window.videoPlaying) {
return "Record";
}
ctx.drawImage(document.getElementById('video_screenshot'), 0, 0, canvas.clientWidth, canvas.clientHeight);
console.log(canvas.toDataURL('image/jpeg', 1.0));
return canvas.toDataURL('image/jpeg', 1.0);
};
}
e.stopPropagation();
window.startRecording();
}, true)
}
// main.js
if (!ScreenCastRecorder.isSupportedBrowser()) {
console.error("Screen Recording not supported in this browser");
}
let recorder;
let outputBlob;
const stopRecording = () => __awaiter(void 0, void 0, void 0, function* () {
let currentState = "RECORDING";
// We should do nothing if the user try to stop recording when it is not started
if (currentState === "OFF" || recorder == null) {
return;
}
// if (currentState === "COUNTDOWN") {
// this.setState({
// currentState: "OFF",
// })
// }
if (currentState === "RECORDING") {
if (recorder.getState() === "inactive") {
// this.setState({
// currentState: "OFF",
// })
console.log("Inactive");
}
else {
outputBlob = yield recorder.stop();
console.log("Done recording");
// this.setState({
// outputBlob,
// currentState: "PREVIEW_FILE",
// })
window.currentState = "PREVIEW_FILE";
const videoSource = URL.createObjectURL(outputBlob);
window.videoSource = videoSource;
const fileName = "recording";
const link = document.createElement("a");
link.setAttribute("href", videoSource);
link.setAttribute("download", `${fileName}.webm`);
link.click();
}
}
});
const startRecording = () => __awaiter(void 0, void 0, void 0, function* () {
const recordAudio = true;
recorder = new ScreenCastRecorder({
recordAudio,
onErrorOrStop: () => stopRecording(),
});
try {
yield recorder.initialize();
}
catch (e) {
console.warn(`ScreenCastRecorder.initialize error: ${e}`);
// this.setState({ currentState: "UNSUPPORTED" })
window.currentState = "UNSUPPORTED";
return;
}
// this.setState({ currentState: "COUNTDOWN" })
const hasStarted = recorder.start();
if (hasStarted) {
// this.setState({
// currentState: "RECORDING",
// })
console.log("Started recording");
window.currentState = "RECORDING";
}
else {
stopRecording().catch(err => console.warn(`withScreencast.stopRecording threw an error: ${err}`));
}
});
// Set global functions to window.
window.startRecording = startRecording;
window.stopRecording = stopRecording;
// recorder.js
var __awaiter = (this && this.__awaiter) || function (thisArg, _arguments, P, generator) {
function adopt(value) { return value instanceof P ? value : new P(function (resolve) { resolve(value); }); }
return new (P || (P = Promise))(function (resolve, reject) {
function fulfilled(value) { try { step(generator.next(value)); } catch (e) { reject(e); } }
function rejected(value) { try { step(generator["throw"](value)); } catch (e) { reject(e); } }
function step(result) { result.done ? resolve(result.value) : adopt(result.value).then(fulfilled, rejected); }
step((generator = generator.apply(thisArg, _arguments || [])).next());
});
};
const BLOB_TYPE = "video/webm";
class ScreenCastRecorder {
/** True if the current browser likely supports screencasts. */
static isSupportedBrowser() {
return (navigator.mediaDevices != null &&
navigator.mediaDevices.getUserMedia != null &&
navigator.mediaDevices.getDisplayMedia != null &&
MediaRecorder.isTypeSupported(BLOB_TYPE));
}
constructor({ recordAudio, onErrorOrStop }) {
this.recordAudio = recordAudio;
this.onErrorOrStopCallback = onErrorOrStop;
this.inputStream = null;
this.recordedChunks = [];
this.mediaRecorder = null;
}
/**
* This asynchronous method will initialize the screen recording object asking
* for permissions to the user which are needed to start recording.
*/
initialize() {
return __awaiter(this, void 0, void 0, function* () {
const desktopStream = yield navigator.mediaDevices.getDisplayMedia({
video: true,
});
let tracks = desktopStream.getTracks();
if (this.recordAudio) {
const voiceStream = yield navigator.mediaDevices.getUserMedia({
video: false,
audio: true,
});
tracks = tracks.concat(voiceStream.getAudioTracks());
}
this.recordedChunks = [];
this.inputStream = new MediaStream(tracks);
let videoElement = document.getElementById('gradio_image_screen_preview').querySelectorAll('video')[0]
// Remove hide class from videoElement
videoElement.classList.remove('hide');
videoElement.classList.remove('flip');
// Set src to inputStream
videoElement.srcObject = this.inputStream;
window.screenInputStream = this.inputStream;
videoElement.play();
// Get width and height of inputStream
//window.videoElement.width = this.inputStream.getVideoTracks()[0].getSettings().width;
window.videoPlaying = 1;
/*setInterval(() => {
document.getElementById("component-2").click()
}, 5000);*/
this.mediaRecorder = new MediaRecorder(this.inputStream, {
mimeType: BLOB_TYPE,
});
this.mediaRecorder.ondataavailable = e => this.recordedChunks.push(e.data);
});
}
getState() {
if (this.mediaRecorder) {
return this.mediaRecorder.state;
}
return "inactive";
}
/**
* This method will start the screen recording if the user has granted permissions
* and the mediaRecorder has been initialized
*
* @returns {boolean}
*/
start() {
if (!this.mediaRecorder) {
console.warn(`ScreenCastRecorder.start: mediaRecorder is null`);
return false;
}
const logRecorderError = (e) => {
console.warn(`mediaRecorder.start threw an error: ${e}`);
};
this.mediaRecorder.onerror = (e) => {
logRecorderError(e);
this.onErrorOrStopCallback();
};
this.mediaRecorder.onstop = () => this.onErrorOrStopCallback();
try {
this.mediaRecorder.start();
}
catch (e) {
logRecorderError(e);
return false;
}
return true;
}
/**
* This method will stop recording and then return the generated Blob
*
* @returns {(Promise|undefined)}
* A Promise which will return the generated Blob
* Undefined if the MediaRecorder could not initialize
*/
stop() {
if (!this.mediaRecorder) {
return undefined;
}
let resolver;
const promise = new Promise(r => {
resolver = r;
});
this.mediaRecorder.onstop = () => resolver();
this.mediaRecorder.stop();
if (this.inputStream) {
this.inputStream.getTracks().forEach(s => s.stop());
this.inputStream = null;
}
return promise.then(() => this.buildOutputBlob());
}
buildOutputBlob() {
return new Blob(this.recordedChunks, { type: BLOB_TYPE });
}
}