Commit f8d86cb0 authored by chenzk's avatar chenzk
Browse files

v1.0

parents
Pipeline #2605 canceled with stages
File added
icon.png

53.8 KB

import soundfile as sf
import torch
from transformers import Qwen2_5OmniModel, Qwen2_5OmniProcessor
from qwen_omni_utils import process_mm_info
'''
FORCE_QWENVL_VIDEO_READER=decord # 强制使用decord 后端
'''
# default: Load the model on the available device(s)
model = Qwen2_5OmniModel.from_pretrained("Qwen/Qwen2.5-Omni-7B", torch_dtype=torch.bfloat16, device_map="auto", attn_implementation="flash_attention_2")
# We recommend enabling flash_attention_2 for better acceleration and memory saving.
# model = Qwen2_5OmniModel.from_pretrained(
# "Qwen/Qwen2.5-Omni-7B",
# torch_dtype="auto",
# device_map="auto",
# attn_implementation="flash_attention_2",
# )
processor = Qwen2_5OmniProcessor.from_pretrained("Qwen/Qwen2.5-Omni-7B")
conversation = [
{
"role": "system",
"content": "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech.",
},
{
"role": "user",
"content": [
{"type": "video", "video": "./draw.mp4"},
],
},
]
# Preparation for inference
text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
audios, images, videos = process_mm_info(conversation, use_audio_in_video=True)
inputs = processor(text=text, audios=audios, images=images, videos=videos, return_tensors="pt", padding=True)
inputs = inputs.to(model.device).to(model.dtype)
# Inference: Generation of the output text and audio
text_ids, audio = model.generate(**inputs, use_audio_in_video=True)
text = processor.batch_decode(text_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
print(text)
sf.write(
"output.wav",
audio.reshape(-1).detach().cpu().numpy(),
samplerate=24000,
)
# 模型编码
modelCode=1478
# 模型名称
modelName=Qwen2.5-Omni_pytorch
# 模型描述
modelDescription=7B参数完成看、听、说、写,端到端多模态大模型支持文本、图像、音频和视频输入。
# 应用场景
appScenario=推理,对话问答,制造,广媒,金融,能源,医疗,家居,教育
# 框架类型
frameType=pytorch
File added
# qwen-omni-utils
Qwen-Omni Utils contains a set of helper functions for processing and integrating visual and audio language information with Qwen-Omni Model.
## Install
```bash
pip install qwen-omni-utils
```
## Usage
### Qwen2Omni
```python
from transformers import Qwen2_5OmniModel, AutoProcessor
from qwen_omni_utils import process_mm_info
# You can directly insert a local file path, a URL, or a base64-encoded image into the position where you want in the text.
messages = [
# Image
## Local file path
[{"role": "user", "content": [{"type": "image", "image": "file:///path/to/your/image.jpg"}, {"type": "text", "text": "Describe this image."}]}],
## Image URL
[{"role": "user", "content": [{"type": "image", "image": "http://path/to/your/image.jpg"}, {"type": "text", "text": "Describe this image."}]}],
## Base64 encoded image
[{"role": "user", "content": [{"type": "image", "image": "data:image;base64,/9j/..."}, {"type": "text", "text": "Describe this image."}]}],
## PIL.Image.Image
[{"role": "user", "content": [{"type": "image", "image": pil_image}, {"type": "text", "text": "Describe this image."}]}],
## Model dynamically adjusts image size, specify dimensions if required.
[{"role": "user", "content": [{"type": "image", "image": "file:///path/to/your/image.jpg", "resized_height": 280, "resized_width": 420}, {"type": "text", "text": "Describe this image."}]}],
# Video
## Local video path
[{"role": "user", "content": [{"type": "video", "video": "file:///path/to/video1.mp4"}, {"type": "text", "text": "Describe this video."}]}],
## Local video frames
[{"role": "user", "content": [{"type": "video", "video": ["file:///path/to/extracted_frame1.jpg", "file:///path/to/extracted_frame2.jpg", "file:///path/to/extracted_frame3.jpg"],}, {"type": "text", "text": "Describe this video."},],}],
## Model dynamically adjusts video nframes, video height and width. specify args if required.
[{"role": "user", "content": [{"type": "video", "video": "file:///path/to/video1.mp4", "fps": 2.0, "resized_height": 280, "resized_width": 280}, {"type": "text", "text": "Describe this video."}]}],
# Audio
## Local audio path
[{"role": "user", "content": [{"type": "audio", "audio": "file:///path/to/audio1.wav"}, {"type": "text", "text": "Describe this audio."}]}],
## Numpy format audio
[{"role": "user", "content": [{"type": "audio", "audio": numpy_audio}, {"type": "text", "text": "Describe this audio."}]}],
## Remote audio
[{"role": "user", "content": [{"type": "audio", "audio": "https://path/to/audio.wav"}, {"type": "text", "text": "Describe this audio."}]}],
]
processor = AutoProcessor.from_pretrained(model_path)
model = Qwen2_5OmniModel.from_pretrained(model_path, torch_dtype="auto", device_map="auto")
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
audios, images, videos = process_mm_info(messages)
inputs = processor(text=text, images=images, videos=videos, audios=audios, padding=True, return_tensors="pt")
print(inputs)
generated_ids, generate_wav = model.generate(**inputs)
print(generated_ids)
```
### Qwen2VL
```python
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_omni_utils import process_vision_info
# You can directly insert a local file path, a URL, or a base64-encoded image into the position where you want in the text.
messages = [
# Image
## Local file path
[{"role": "user", "content": [{"type": "image", "image": "file:///path/to/your/image.jpg"}, {"type": "text", "text": "Describe this image."}]}],
## Image URL
[{"role": "user", "content": [{"type": "image", "image": "http://path/to/your/image.jpg"}, {"type": "text", "text": "Describe this image."}]}],
## Base64 encoded image
[{"role": "user", "content": [{"type": "image", "image": "data:image;base64,/9j/..."}, {"type": "text", "text": "Describe this image."}]}],
## PIL.Image.Image
[{"role": "user", "content": [{"type": "image", "image": pil_image}, {"type": "text", "text": "Describe this image."}]}],
## Model dynamically adjusts image size, specify dimensions if required.
[{"role": "user", "content": [{"type": "image", "image": "file:///path/to/your/image.jpg", "resized_height": 280, "resized_width": 420}, {"type": "text", "text": "Describe this image."}]}],
# Video
## Local video path
[{"role": "user", "content": [{"type": "video", "video": "file:///path/to/video1.mp4"}, {"type": "text", "text": "Describe this video."}]}],
## Local video frames
[{"role": "user", "content": [{"type": "video", "video": ["file:///path/to/extracted_frame1.jpg", "file:///path/to/extracted_frame2.jpg", "file:///path/to/extracted_frame3.jpg"],}, {"type": "text", "text": "Describe this video."},],}],
## Model dynamically adjusts video nframes, video height and width. specify args if required.
[{"role": "user", "content": [{"type": "video", "video": "file:///path/to/video1.mp4", "fps": 2.0, "resized_height": 280, "resized_width": 280}, {"type": "text", "text": "Describe this video."}]}],
]
processor = AutoProcessor.from_pretrained(model_path)
model = Qwen2VLForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto", device_map="auto")
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
images, videos = process_vision_info(messages)
inputs = processor(text=text, images=images, videos=videos, padding=True, return_tensors="pt")
print(inputs)
generated_ids = model.generate(**inputs)
print(generated_ids)
```
### Qwen2.5VL
```python
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_omni_utils import process_vision_info
# You can set the maximum tokens for a video through the environment variable VIDEO_MAX_PIXELS
# based on the maximum tokens that the model can accept.
# export VIDEO_MAX_PIXELS = 32000 * 28 * 28 * 0.9
# You can directly insert a local file path, a URL, or a base64-encoded image into the position where you want in the text.
messages = [
# Image
## Local file path
[{"role": "user", "content": [{"type": "image", "image": "file:///path/to/your/image.jpg"}, {"type": "text", "text": "Describe this image."}]}],
## Image URL
[{"role": "user", "content": [{"type": "image", "image": "http://path/to/your/image.jpg"}, {"type": "text", "text": "Describe this image."}]}],
## Base64 encoded image
[{"role": "user", "content": [{"type": "image", "image": "data:image;base64,/9j/..."}, {"type": "text", "text": "Describe this image."}]}],
## PIL.Image.Image
[{"role": "user", "content": [{"type": "image", "image": pil_image}, {"type": "text", "text": "Describe this image."}]}],
## Model dynamically adjusts image size, specify dimensions if required.
[{"role": "user", "content": [{"type": "image", "image": "file:///path/to/your/image.jpg", "resized_height": 280, "resized_width": 420}, {"type": "text", "text": "Describe this image."}]}],
# Video
## Local video path
[{"role": "user", "content": [{"type": "video", "video": "file:///path/to/video1.mp4"}, {"type": "text", "text": "Describe this video."}]}],
## Local video frames
[{"role": "user", "content": [{"type": "video", "video": ["file:///path/to/extracted_frame1.jpg", "file:///path/to/extracted_frame2.jpg", "file:///path/to/extracted_frame3.jpg"],}, {"type": "text", "text": "Describe this video."},],}],
## Model dynamically adjusts video nframes, video height and width. specify args if required.
[{"role": "user", "content": [{"type": "video", "video": "file:///path/to/video1.mp4", "fps": 2.0, "resized_height": 280, "resized_width": 280}, {"type": "text", "text": "Describe this video."}]}],
]
processor = AutoProcessor.from_pretrained(model_path)
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto", device_map="auto")
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
images, videos, video_kwargs = process_vision_info(messages, return_video_kwargs=True)
inputs = processor(text=text, images=images, videos=videos, padding=True, return_tensors="pt", **video_kwargs)
print(inputs)
generated_ids = model.generate(**inputs)
print(generated_ids)
```
\ No newline at end of file
[project]
name = "qwen-omni-utils"
version = "0.0.3"
description = "Qwen Omni Language Model Utils - PyTorch"
authors = [
{ name = "Qwen Team", email = "lvyuanjun.lyj@alibaba-inc.com" },
]
dependencies = [
"requests",
"pillow",
"av",
"packaging",
"librosa",
]
readme = "README.md"
requires-python = ">= 3.8"
license = {text = "Apache-2.0"}
keywords = [
'large language model',
'vision language model',
'qwen-omni',
'pytorch',
]
classifiers = [
'Development Status :: 4 - Beta',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'Programming Language :: Python :: 3',
'License :: OSI Approved :: Apache Software License',
]
[project.urls]
Homepage = "https://github.com/QwenLM/Qwen2-VL/tree/main/qwen-vl-utils"
Repository = "https://github.com/QwenLM/Qwen2-VL.git"
Issues = "https://github.com/QwenLM/Qwen2-VL/issues"
[project.optional-dependencies]
decord = [
"decord",
]
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.rye]
managed = true
dev-dependencies = [
"torch",
"torchvision",
"torchaudio",
]
[tool.hatch.metadata]
allow-direct-references = true
[tool.hatch.build.targets.wheel]
packages = ["src/qwen_omni_utils"]
[tool.ruff]
line-length = 119
[tool.ruff.lint]
ignore = ["C408", "C901", "E501", "E731", "E741", "W605"]
select = ["C", "E", "F", "I", "W"]
[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["E402", "F401", "F403", "F811"]
[tool.ruff.lint.isort]
lines-after-imports = 2
known-first-party = ["qwen_omni_utils"]
[tool.ruff.format]
quote-style = "double"
indent-style = "space"
skip-magic-trailing-comma = false
line-ending = "auto"
\ No newline at end of file
from .audio_process import process_audio_info
from .vision_process import (
extract_vision_info,
fetch_image,
fetch_video,
process_vision_info,
smart_resize,
)
def process_mm_info(conversations, use_audio_in_video, return_video_kwargs=False):
audios = process_audio_info(conversations, use_audio_in_video)
vision = process_vision_info(conversations, return_video_kwargs=return_video_kwargs)
return (audios,) + vision
import audioread
import av
import librosa
import numpy as np
def _check_if_video_has_audio(video_path):
container = av.open(video_path)
audio_streams = [stream for stream in container.streams if stream.type == "audio"]
if not audio_streams:
return False
return True
def process_audio_info(conversations: list[dict] | list[list[dict]], use_audio_in_video: bool):
audios = []
if isinstance(conversations[0], dict):
conversations = [conversations]
for conversation in conversations:
for message in conversation:
if not isinstance(message["content"], list):
continue
for ele in message["content"]:
if ele["type"] == "audio":
if "audio" in ele:
path = ele["audio"]
if path.startswith("http://") or path.startswith("https://"):
audios.append(librosa.load(audioread.ffdec.FFmpegAudioFile(path), sr=16000)[0])
elif isinstance(path, np.ndarray):
if path.ndim > 1:
raise ValueError("Support only mono audio")
audios.append(path)
elif path.startswith("file://"):
audios.append(librosa.load(path[len("file://") :], sr=16000)[0])
else:
audios.append(librosa.load(path, sr=16000)[0])
else:
raise ValueError("Unknown audio {}".format(ele))
if use_audio_in_video and ele["type"] == "video":
if "video" in ele:
path = ele["video"]
assert _check_if_video_has_audio(
path
), "Video must has audio track when use_audio_in_video=True"
if path.startswith("http://") or path.startswith("https://"):
audios.append(librosa.load(audioread.ffdec.FFmpegAudioFile(path), sr=16000)[0])
elif path.startswith("file://"):
audios.append(librosa.load(path[len("file://") :], sr=16000)[0])
else:
audios.append(librosa.load(path, sr=16000)[0])
else:
raise ValueError("Unknown video {}".format(ele))
if len(audios) == 0:
audios = None
return audios
from __future__ import annotations
import base64
import logging
import math
import os
import sys
import time
import warnings
from functools import lru_cache
from io import BytesIO
from typing import Optional
import requests
import torch
import torchvision
from packaging import version
from PIL import Image
from torchvision import io, transforms
from torchvision.transforms import InterpolationMode
logger = logging.getLogger(__name__)
IMAGE_FACTOR = 28
MIN_PIXELS = 4 * 28 * 28
MAX_PIXELS = 16384 * 28 * 28
MAX_RATIO = 200
VIDEO_MIN_PIXELS = 128 * 28 * 28
VIDEO_MAX_PIXELS = 768 * 28 * 28
FRAME_FACTOR = 2
FPS = 2.0
FPS_MIN_FRAMES = 4
FPS_MAX_FRAMES = 768
# Set the maximum number of video token inputs.
# Here, 128K represents the maximum number of input tokens for the VLLM model.
# Remember to adjust it according to your own configuration.
VIDEO_TOTAL_PIXELS = int(float(os.environ.get('VIDEO_MAX_PIXELS', 128000 * 28 * 28 * 0.9)))
logger.info(f"set VIDEO_TOTAL_PIXELS: {VIDEO_TOTAL_PIXELS}")
def round_by_factor(number: int, factor: int) -> int:
"""Returns the closest integer to 'number' that is divisible by 'factor'."""
return round(number / factor) * factor
def ceil_by_factor(number: int, factor: int) -> int:
"""Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
return math.ceil(number / factor) * factor
def floor_by_factor(number: int, factor: int) -> int:
"""Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
return math.floor(number / factor) * factor
def smart_resize(
height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
) -> tuple[int, int]:
"""
Rescales the image so that the following conditions are met:
1. Both dimensions (height and width) are divisible by 'factor'.
2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
3. The aspect ratio of the image is maintained as closely as possible.
"""
if max(height, width) / min(height, width) > MAX_RATIO:
raise ValueError(
f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
)
h_bar = max(factor, round_by_factor(height, factor))
w_bar = max(factor, round_by_factor(width, factor))
if h_bar * w_bar > max_pixels:
beta = math.sqrt((height * width) / max_pixels)
h_bar = floor_by_factor(height / beta, factor)
w_bar = floor_by_factor(width / beta, factor)
elif h_bar * w_bar < min_pixels:
beta = math.sqrt(min_pixels / (height * width))
h_bar = ceil_by_factor(height * beta, factor)
w_bar = ceil_by_factor(width * beta, factor)
return h_bar, w_bar
def to_rgb(pil_image: Image.Image) -> Image.Image:
if pil_image.mode == 'RGBA':
white_background = Image.new("RGB", pil_image.size, (255, 255, 255))
white_background.paste(pil_image, mask=pil_image.split()[3]) # Use alpha channel as mask
return white_background
else:
return pil_image.convert("RGB")
def fetch_image(ele: dict[str, str | Image.Image], size_factor: int = IMAGE_FACTOR) -> Image.Image:
if "image" in ele:
image = ele["image"]
else:
image = ele["image_url"]
image_obj = None
if isinstance(image, Image.Image):
image_obj = image
elif image.startswith("http://") or image.startswith("https://"):
response = requests.get(image, stream=True)
image_obj = Image.open(BytesIO(response.content))
elif image.startswith("file://"):
image_obj = Image.open(image[7:])
elif image.startswith("data:image"):
if "base64," in image:
_, base64_data = image.split("base64,", 1)
data = base64.b64decode(base64_data)
image_obj = Image.open(BytesIO(data))
else:
image_obj = Image.open(image)
if image_obj is None:
raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}")
image = to_rgb(image_obj)
## resize
if "resized_height" in ele and "resized_width" in ele:
resized_height, resized_width = smart_resize(
ele["resized_height"],
ele["resized_width"],
factor=size_factor,
)
else:
width, height = image.size
min_pixels = ele.get("min_pixels", MIN_PIXELS)
max_pixels = ele.get("max_pixels", MAX_PIXELS)
resized_height, resized_width = smart_resize(
height,
width,
factor=size_factor,
min_pixels=min_pixels,
max_pixels=max_pixels,
)
image = image.resize((resized_width, resized_height))
return image
def smart_nframes(
ele: dict,
total_frames: int,
video_fps: int | float,
) -> int:
"""calculate the number of frames for video used for model inputs.
Args:
ele (dict): a dict contains the configuration of video.
support either `fps` or `nframes`:
- nframes: the number of frames to extract for model inputs.
- fps: the fps to extract frames for model inputs.
- min_frames: the minimum number of frames of the video, only used when fps is provided.
- max_frames: the maximum number of frames of the video, only used when fps is provided.
total_frames (int): the original total number of frames of the video.
video_fps (int | float): the original fps of the video.
Raises:
ValueError: nframes should in interval [FRAME_FACTOR, total_frames].
Returns:
int: the number of frames for video used for model inputs.
"""
assert not ("fps" in ele and "nframes" in ele), "Only accept either `fps` or `nframes`"
if "nframes" in ele:
nframes = round_by_factor(ele["nframes"], FRAME_FACTOR)
else:
fps = ele.get("fps", FPS)
min_frames = ceil_by_factor(ele.get("min_frames", FPS_MIN_FRAMES), FRAME_FACTOR)
max_frames = floor_by_factor(ele.get("max_frames", min(FPS_MAX_FRAMES, total_frames)), FRAME_FACTOR)
nframes = total_frames / video_fps * fps
if nframes > total_frames:
logger.warning(f"smart_nframes: nframes[{nframes}] > total_frames[{total_frames}]")
nframes = min(min(max(nframes, min_frames), max_frames), total_frames)
nframes = floor_by_factor(nframes, FRAME_FACTOR)
if not (FRAME_FACTOR <= nframes and nframes <= total_frames):
raise ValueError(f"nframes should in interval [{FRAME_FACTOR}, {total_frames}], but got {nframes}.")
return nframes
def _read_video_torchvision(
ele: dict,
) -> (torch.Tensor, float):
"""read video using torchvision.io.read_video
Args:
ele (dict): a dict contains the configuration of video.
support keys:
- video: the path of video. support "file://", "http://", "https://" and local path.
- video_start: the start time of video.
- video_end: the end time of video.
Returns:
torch.Tensor: the video tensor with shape (T, C, H, W).
"""
video_path = ele["video"]
if version.parse(torchvision.__version__) < version.parse("0.19.0"):
if "http://" in video_path or "https://" in video_path:
warnings.warn("torchvision < 0.19.0 does not support http/https video path, please upgrade to 0.19.0.")
if "file://" in video_path:
video_path = video_path[7:]
st = time.time()
video, audio, info = io.read_video(
video_path,
start_pts=ele.get("video_start", 0.0),
end_pts=ele.get("video_end", None),
pts_unit="sec",
output_format="TCHW",
)
total_frames, video_fps = video.size(0), info["video_fps"]
logger.info(f"torchvision: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s")
nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
idx = torch.linspace(0, total_frames - 1, nframes).round().long()
sample_fps = nframes / max(total_frames, 1e-6) * video_fps
video = video[idx]
return video, sample_fps
def is_decord_available() -> bool:
import importlib.util
return importlib.util.find_spec("decord") is not None
def _read_video_decord(
ele: dict,
) -> (torch.Tensor, float):
"""read video using decord.VideoReader
Args:
ele (dict): a dict contains the configuration of video.
support keys:
- video: the path of video. support "file://", "http://", "https://" and local path.
- video_start: the start time of video.
- video_end: the end time of video.
Returns:
torch.Tensor: the video tensor with shape (T, C, H, W).
"""
import decord
video_path = ele["video"]
st = time.time()
vr = decord.VideoReader(video_path)
# TODO: support start_pts and end_pts
if 'video_start' in ele or 'video_end' in ele:
raise NotImplementedError("not support start_pts and end_pts in decord for now.")
total_frames, video_fps = len(vr), vr.get_avg_fps()
logger.info(f"decord: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s")
nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist()
video = vr.get_batch(idx).asnumpy()
video = torch.tensor(video).permute(0, 3, 1, 2) # Convert to TCHW format
sample_fps = nframes / max(total_frames, 1e-6) * video_fps
return video, sample_fps
VIDEO_READER_BACKENDS = {
"decord": _read_video_decord,
"torchvision": _read_video_torchvision,
}
FORCE_QWENVL_VIDEO_READER = os.getenv("FORCE_QWENVL_VIDEO_READER", None)
@lru_cache(maxsize=1)
def get_video_reader_backend() -> str:
if FORCE_QWENVL_VIDEO_READER is not None:
video_reader_backend = FORCE_QWENVL_VIDEO_READER
elif is_decord_available():
video_reader_backend = "decord"
else:
video_reader_backend = "torchvision"
print(f"qwen-vl-utils using {video_reader_backend} to read video.", file=sys.stderr)
return video_reader_backend
def fetch_video(ele: dict, image_factor: int = IMAGE_FACTOR, return_video_sample_fps: bool = False) -> torch.Tensor | list[Image.Image]:
if isinstance(ele["video"], str):
video_reader_backend = get_video_reader_backend()
try:
video, sample_fps = VIDEO_READER_BACKENDS[video_reader_backend](ele)
except Exception as e:
logger.warning(f"video_reader_backend {video_reader_backend} error, use torchvision as default, msg: {e}")
video, sample_fps = VIDEO_READER_BACKENDS["torchvision"](ele)
nframes, _, height, width = video.shape
min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS)
total_pixels = ele.get("total_pixels", VIDEO_TOTAL_PIXELS)
max_pixels = max(min(VIDEO_MAX_PIXELS, total_pixels / nframes * FRAME_FACTOR), int(min_pixels * 1.05))
max_pixels_supposed = ele.get("max_pixels", max_pixels)
if max_pixels_supposed > max_pixels:
logger.warning(f"The given max_pixels[{max_pixels_supposed}] exceeds limit[{max_pixels}].")
max_pixels = min(max_pixels_supposed, max_pixels)
if "resized_height" in ele and "resized_width" in ele:
resized_height, resized_width = smart_resize(
ele["resized_height"],
ele["resized_width"],
factor=image_factor,
)
else:
resized_height, resized_width = smart_resize(
height,
width,
factor=image_factor,
min_pixels=min_pixels,
max_pixels=max_pixels,
)
video = transforms.functional.resize(
video,
[resized_height, resized_width],
interpolation=InterpolationMode.BICUBIC,
antialias=True,
).float()
if return_video_sample_fps:
return video, sample_fps
return video
else:
assert isinstance(ele["video"], (list, tuple))
process_info = ele.copy()
process_info.pop("type", None)
process_info.pop("video", None)
images = [
fetch_image({"image": video_element, **process_info}, size_factor=image_factor)
for video_element in ele["video"]
]
nframes = ceil_by_factor(len(images), FRAME_FACTOR)
if len(images) < nframes:
images.extend([images[-1]] * (nframes - len(images)))
if return_video_sample_fps:
return images, process_info.pop("fps", 2.0)
return images
def extract_vision_info(conversations: list[dict] | list[list[dict]]) -> list[dict]:
vision_infos = []
if isinstance(conversations[0], dict):
conversations = [conversations]
for conversation in conversations:
for message in conversation:
if isinstance(message["content"], list):
for ele in message["content"]:
if (
"image" in ele
or "image_url" in ele
or "video" in ele
or ele["type"] in ("image", "image_url", "video")
):
vision_infos.append(ele)
return vision_infos
def process_vision_info(
conversations: list[dict] | list[list[dict]],
return_video_kwargs: bool = False,
) -> tuple[list[Image.Image] | None, list[torch.Tensor | list[Image.Image]] | None, Optional[dict]]:
vision_infos = extract_vision_info(conversations)
## Read images or videos
image_inputs = []
video_inputs = []
video_sample_fps_list = []
for vision_info in vision_infos:
if "image" in vision_info or "image_url" in vision_info:
image_inputs.append(fetch_image(vision_info))
elif "video" in vision_info:
video_input, video_sample_fps = fetch_video(vision_info, return_video_sample_fps=True)
video_sample_fps_list.append(video_sample_fps)
video_inputs.append(video_input)
else:
raise ValueError("image, image_url or video should in content.")
if len(image_inputs) == 0:
image_inputs = None
if len(video_inputs) == 0:
video_inputs = None
if return_video_kwargs:
return image_inputs, video_inputs, {'fps': video_sample_fps_list}
return image_inputs, video_inputs
\ No newline at end of file
# Core dependencies
gradio==5.23.1
gradio_client==1.8.0
qwen-omni-utils==0.0.3
librosa==0.11.0
ffmpeg==1.4
ffmpeg-python==0.2.0
soundfile==0.13.1
modelscope_studio==1.2.2
# git+https://github.com/huggingface/transformers@f742a644ca32e65758c3adb36225aef1731bd2a8
accelerate
av
qwen-vl-utils[decord]
# Optional dependency
# Uncomment the following line if you need flash-attn
# flash-attn==2.7.4.post1
# Core dependencies
gradio==5.23.1
gradio_client==1.8.0
qwen-omni-utils==0.0.3
librosa==0.11.0
ffmpeg==1.4
ffmpeg-python==0.2.0
soundfile==0.13.1
modelscope_studio==1.2.2
git+https://github.com/huggingface/transformers@f742a644ca32e65758c3adb36225aef1731bd2a8
accelerate
av
# Optional dependency
# Uncomment the following line if you need flash-attn
# flash-attn==2.7.4.post1
import io
import os
import ffmpeg
import numpy as np
import gradio as gr
import soundfile as sf
import modelscope_studio.components.base as ms
import modelscope_studio.components.antd as antd
import gradio.processing_utils as processing_utils
from transformers import Qwen2_5OmniModel, Qwen2_5OmniProcessor
from gradio_client import utils as client_utils
from qwen_omni_utils import process_mm_info
from argparse import ArgumentParser
def _load_model_processor(args):
if args.cpu_only:
device_map = 'cpu'
else:
device_map = 'auto'
# Check if flash-attn2 flag is enabled and load model accordingly
if args.flash_attn2:
model = Qwen2_5OmniModel.from_pretrained(args.checkpoint_path,
torch_dtype='auto',
attn_implementation='flash_attention_2',
device_map=device_map)
else:
model = Qwen2_5OmniModel.from_pretrained(args.checkpoint_path, device_map=device_map)
processor = Qwen2_5OmniProcessor.from_pretrained(args.checkpoint_path)
return model, processor
def _launch_demo(args, model, processor):
# Voice settings
VOICE_LIST = ['Chelsie', 'Ethan']
DEFAULT_VOICE = 'Chelsie'
default_system_prompt = 'You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech.'
language = args.ui_language
def get_text(text: str, cn_text: str):
if language == 'en':
return text
if language == 'zh':
return cn_text
return text
def convert_webm_to_mp4(input_file, output_file):
try:
(
ffmpeg
.input(input_file)
.output(output_file, acodec='aac', ar='16000', audio_bitrate='192k')
.run(quiet=True, overwrite_output=True)
)
print(f"Conversion successful: {output_file}")
except ffmpeg.Error as e:
print("An error occurred during conversion.")
print(e.stderr.decode('utf-8'))
def format_history(history: list, system_prompt: str):
messages = []
messages.append({"role": "system", "content": system_prompt})
for item in history:
if isinstance(item["content"], str):
messages.append({"role": item['role'], "content": item['content']})
elif item["role"] == "user" and (isinstance(item["content"], list) or
isinstance(item["content"], tuple)):
file_path = item["content"][0]
mime_type = client_utils.get_mimetype(file_path)
if mime_type.startswith("image"):
messages.append({
"role":
item['role'],
"content": [{
"type": "image",
"image": file_path
}]
})
elif mime_type.startswith("video"):
messages.append({
"role":
item['role'],
"content": [{
"type": "video",
"video": file_path
}]
})
elif mime_type.startswith("audio"):
messages.append({
"role":
item['role'],
"content": [{
"type": "audio",
"audio": file_path,
}]
})
return messages
def predict(messages, voice=DEFAULT_VOICE):
print('predict history: ', messages)
text = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
audios, images, videos = process_mm_info(messages, use_audio_in_video=True)
inputs = processor(text=text, audios=audios, images=images, videos=videos, return_tensors="pt", padding=True, use_audio_in_video=True)
inputs = inputs.to(model.device).to(model.dtype)
text_ids, audio = model.generate(**inputs, spk=voice, use_audio_in_video=True)
response = processor.batch_decode(text_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
response = response[0].split("\n")[-1]
yield {"type": "text", "data": response}
audio = np.array(audio * 32767).astype(np.int16)
wav_io = io.BytesIO()
sf.write(wav_io, audio, samplerate=24000, format="WAV")
wav_io.seek(0)
wav_bytes = wav_io.getvalue()
audio_path = processing_utils.save_bytes_to_cache(
wav_bytes, "audio.wav", cache_dir=demo.GRADIO_CACHE)
yield {"type": "audio", "data": audio_path}
def media_predict(audio, video, history, system_prompt, voice_choice):
# First yield
yield (
None, # microphone
None, # webcam
history, # media_chatbot
gr.update(visible=False), # submit_btn
gr.update(visible=True), # stop_btn
)
if video is not None:
convert_webm_to_mp4(video, video.replace('.webm', '.mp4'))
video = video.replace(".webm", ".mp4")
files = [audio, video]
for f in files:
if f:
history.append({"role": "user", "content": (f, )})
formatted_history = format_history(history=history,
system_prompt=system_prompt,)
history.append({"role": "assistant", "content": ""})
for chunk in predict(formatted_history, voice_choice):
if chunk["type"] == "text":
history[-1]["content"] = chunk["data"]
yield (
None, # microphone
None, # webcam
history, # media_chatbot
gr.update(visible=False), # submit_btn
gr.update(visible=True), # stop_btn
)
if chunk["type"] == "audio":
history.append({
"role": "assistant",
"content": gr.Audio(chunk["data"])
})
# Final yield
yield (
None, # microphone
None, # webcam
history, # media_chatbot
gr.update(visible=True), # submit_btn
gr.update(visible=False), # stop_btn
)
def chat_predict(text, audio, image, video, history, system_prompt, voice_choice):
# Process text input
if text:
history.append({"role": "user", "content": text})
# Process audio input
if audio:
history.append({"role": "user", "content": (audio, )})
# Process image input
if image:
history.append({"role": "user", "content": (image, )})
# Process video input
if video:
history.append({"role": "user", "content": (video, )})
formatted_history = format_history(history=history,
system_prompt=system_prompt)
yield None, None, None, None, history
history.append({"role": "assistant", "content": ""})
for chunk in predict(formatted_history, voice_choice):
if chunk["type"] == "text":
history[-1]["content"] = chunk["data"]
yield gr.skip(), gr.skip(), gr.skip(), gr.skip(
), history
if chunk["type"] == "audio":
history.append({
"role": "assistant",
"content": gr.Audio(chunk["data"])
})
yield gr.skip(), gr.skip(), gr.skip(), gr.skip(), history
with gr.Blocks() as demo, ms.Application(), antd.ConfigProvider():
with gr.Sidebar(open=False):
system_prompt_textbox = gr.Textbox(label="System Prompt",
value=default_system_prompt)
with antd.Flex(gap="small", justify="center", align="center"):
with antd.Flex(vertical=True, gap="small", align="center"):
antd.Typography.Title("Qwen2.5-Omni Demo",
level=1,
elem_style=dict(margin=0, fontSize=28))
with antd.Flex(vertical=True, gap="small"):
antd.Typography.Text(get_text("🎯 Instructions for use:",
"🎯 使用说明:"),
strong=True)
antd.Typography.Text(
get_text(
"1️⃣ Click the Audio Record button or the Camera Record button.",
"1️⃣ 点击音频录制按钮,或摄像头-录制按钮"))
antd.Typography.Text(
get_text("2️⃣ Input audio or video.", "2️⃣ 输入音频或者视频"))
antd.Typography.Text(
get_text(
"3️⃣ Click the submit button and wait for the model's response.",
"3️⃣ 点击提交并等待模型的回答"))
voice_choice = gr.Dropdown(label="Voice Choice",
choices=VOICE_LIST,
value=DEFAULT_VOICE)
with gr.Tabs():
with gr.Tab("Online"):
with gr.Row():
with gr.Column(scale=1):
microphone = gr.Audio(sources=['microphone'],
type="filepath")
webcam = gr.Video(sources=['webcam'],
height=400,
include_audio=True)
submit_btn = gr.Button(get_text("Submit", "提交"),
variant="primary")
stop_btn = gr.Button(get_text("Stop", "停止"), visible=False)
clear_btn = gr.Button(get_text("Clear History", "清除历史"))
with gr.Column(scale=2):
media_chatbot = gr.Chatbot(height=650, type="messages")
def clear_history():
return [], gr.update(value=None), gr.update(value=None)
submit_event = submit_btn.click(fn=media_predict,
inputs=[
microphone, webcam,
media_chatbot,
system_prompt_textbox,
voice_choice
],
outputs=[
microphone, webcam,
media_chatbot, submit_btn,
stop_btn
])
stop_btn.click(
fn=lambda:
(gr.update(visible=True), gr.update(visible=False)),
inputs=None,
outputs=[submit_btn, stop_btn],
cancels=[submit_event],
queue=False)
clear_btn.click(fn=clear_history,
inputs=None,
outputs=[media_chatbot, microphone, webcam])
with gr.Tab("Offline"):
chatbot = gr.Chatbot(type="messages", height=650)
# Media upload section in one row
with gr.Row(equal_height=True):
audio_input = gr.Audio(sources=["upload"],
type="filepath",
label="Upload Audio",
elem_classes="media-upload",
scale=1)
image_input = gr.Image(sources=["upload"],
type="filepath",
label="Upload Image",
elem_classes="media-upload",
scale=1)
video_input = gr.Video(sources=["upload"],
label="Upload Video",
elem_classes="media-upload",
scale=1)
# Text input section
text_input = gr.Textbox(show_label=False,
placeholder="Enter text here...")
# Control buttons
with gr.Row():
submit_btn = gr.Button(get_text("Submit", "提交"),
variant="primary",
size="lg")
stop_btn = gr.Button(get_text("Stop", "停止"),
visible=False,
size="lg")
clear_btn = gr.Button(get_text("Clear History", "清除历史"),
size="lg")
def clear_chat_history():
return [], gr.update(value=None), gr.update(
value=None), gr.update(value=None), gr.update(value=None)
submit_event = gr.on(
triggers=[submit_btn.click, text_input.submit],
fn=chat_predict,
inputs=[
text_input, audio_input, image_input, video_input, chatbot,
system_prompt_textbox, voice_choice
],
outputs=[
text_input, audio_input, image_input, video_input, chatbot
])
stop_btn.click(fn=lambda:
(gr.update(visible=True), gr.update(visible=False)),
inputs=None,
outputs=[submit_btn, stop_btn],
cancels=[submit_event],
queue=False)
clear_btn.click(fn=clear_chat_history,
inputs=None,
outputs=[
chatbot, text_input, audio_input, image_input,
video_input
])
# Add some custom CSS to improve the layout
gr.HTML("""
<style>
.media-upload {
margin: 10px;
min-height: 160px;
}
.media-upload > .wrap {
border: 2px dashed #ccc;
border-radius: 8px;
padding: 10px;
height: 100%;
}
.media-upload:hover > .wrap {
border-color: #666;
}
/* Make upload areas equal width */
.media-upload {
flex: 1;
min-width: 0;
}
</style>
""")
demo.queue(default_concurrency_limit=100, max_size=100).launch(max_threads=100,
ssr_mode=False,
share=args.share,
inbrowser=args.inbrowser,
server_port=args.server_port,
server_name=args.server_name,)
DEFAULT_CKPT_PATH = "Qwen/Qwen2.5-Omni-7B"
def _get_args():
parser = ArgumentParser()
parser.add_argument('-c',
'--checkpoint-path',
type=str,
default=DEFAULT_CKPT_PATH,
help='Checkpoint name or path, default to %(default)r')
parser.add_argument('--cpu-only', action='store_true', help='Run demo with CPU only')
parser.add_argument('--flash-attn2',
action='store_true',
default=False,
help='Enable flash_attention_2 when loading the model.')
parser.add_argument('--share',
action='store_true',
default=False,
help='Create a publicly shareable link for the interface.')
parser.add_argument('--inbrowser',
action='store_true',
default=False,
help='Automatically launch the interface in a new tab on the default browser.')
parser.add_argument('--server-port', type=int, default=7860, help='Demo server port.')
parser.add_argument('--server-name', type=str, default='127.0.0.1', help='Demo server name.')
parser.add_argument('--ui-language', type=str, choices=['en', 'zh'], default='en', help='Display language for the UI.')
args = parser.parse_args()
return args
if __name__ == "__main__":
args = _get_args()
model, processor = _load_model_processor(args)
_launch_demo(args, model, processor)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment