Commit 3d735feb authored by luopl

"Initial commit"
import json
import os
from tqdm import tqdm
from datasets import load_dataset
def validate_data(json_file_path, media_folder_path):
"""
Main function to validate JSON data by checking:
1. Media file existence (supports both image and video fields)
2. Media token consistency in conversations
Saves valid and problematic data to separate files
"""
# Validate input file format
if not json_file_path.endswith((".json", ".jsonl")):
print("Invalid file format. Please provide a .json or .jsonl file.")
return
# Prepare output file paths
base_path = os.path.splitext(json_file_path)[0]
valid_file_path = f"{base_path}_valid.json"
problem_file_path = f"{base_path}_problems.json"
# Load the dataset
try:
data = load_dataset("json", data_files=json_file_path)["train"]
except Exception as e:
print(f"Error loading dataset: {e}")
return
valid_data = []
problem_data = []
stats = {
'total_entries': 0,
'valid_entries': 0,
'missing_media': 0,
'token_mismatches': 0,
'gpt_media_tokens': 0,
'missing_files': [],
'media_types': {
'image': 0,
'video': 0,
'mixed': 0
}
}
print(f"Processing {len(data)} entries...")
for item in tqdm(data):
stats['total_entries'] += 1
problems = []
# Check media file existence (handle both singular and plural fields)
media_info = {
'image': item.get("image", item.get("images", [])),
'video': item.get("video", item.get("videos", []))
}
# Convert all media fields to lists
for media_type in media_info:
if isinstance(media_info[media_type], str):
media_info[media_type] = [media_info[media_type]]
elif not isinstance(media_info[media_type], list):
media_info[media_type] = []
# Count media types for stats
media_counts = {k: len(v) for k, v in media_info.items()}
active_media = [k for k, v in media_counts.items() if v > 0]
if len(active_media) > 1:
stats['media_types']['mixed'] += 1
elif len(active_media) == 1:
stats['media_types'][active_media[0]] += 1
# Check all media files exist
missing_files = []
for media_type, files in media_info.items():
for media_file in files:
media_path = os.path.join(media_folder_path, media_file)
if not os.path.exists(media_path):
missing_files.append(media_path)
if missing_files:
stats['missing_media'] += 1
stats['missing_files'].extend(missing_files)
problems.append({
'type': 'missing_files',
'files': missing_files,
'message': f"Missing media files: {missing_files}"
})
# Check media token consistency
conversations = item.get("conversations", [])
expected_counts = {
'image': media_counts['image'],
'video': media_counts['video']
}
actual_counts = {
'image': 0,
'video': 0
}
gpt_has_media_token = False
for conv in conversations:
if conv.get("from") == "human":
actual_counts['image'] += conv.get("value", "").count("<image>")
actual_counts['video'] += conv.get("value", "").count("<video>")
elif conv.get("from") == "gpt":
if "<image>" in conv.get("value", "") or "<video>" in conv.get("value", ""):
gpt_has_media_token = True
# Check token counts match media counts
for media_type in ['image', 'video']:
if actual_counts[media_type] != expected_counts[media_type]:
stats['token_mismatches'] += 1
problems.append({
'type': 'token_mismatch',
'media_type': media_type,
'expected': expected_counts[media_type],
'actual': actual_counts[media_type],
'message': f"Expected {expected_counts[media_type]} <{media_type}> tokens, found {actual_counts[media_type]}"
})
break # Count each entry only once for mismatches
if gpt_has_media_token:
stats['gpt_media_tokens'] += 1
problems.append({
'type': 'gpt_media_token',
'message': "GPT response contains media token (<image> or <video>)"
})
# Categorize the item
if not problems:
stats['valid_entries'] += 1
valid_data.append(item)
else:
problem_item = item.copy()
problem_item['validation_problems'] = problems
problem_data.append(problem_item)
# Save results
with open(valid_file_path, 'w') as f:
json.dump(valid_data, f, indent=2)
with open(problem_file_path, 'w') as f:
json.dump(problem_data, f, indent=2)
# Print summary
print("\nValidation Summary:")
print(f"Total entries processed: {stats['total_entries']}")
print(f"Valid entries: {stats['valid_entries']} ({stats['valid_entries']/stats['total_entries']:.1%})")
print(f"Media type distribution:")
print(f" - Image only: {stats['media_types']['image']}")
print(f" - Video only: {stats['media_types']['video']}")
print(f" - Mixed media: {stats['media_types']['mixed']}")
print(f"Entries with missing media: {stats['missing_media']}")
print(f"Entries with token mismatches: {stats['token_mismatches']}")
print(f"Entries with GPT media tokens: {stats['gpt_media_tokens']}")
if stats['missing_files']:
print("\nSample missing files (max 5):")
for f in stats['missing_files'][:5]:
print(f" - {f}")
# Example usage
if __name__ == "__main__":
json_file_path = "example.json" # Replace with your JSON file path
media_folder_path = "media" # Replace with your media folder path
validate_data(json_file_path, media_folder_path)
import json
import os
import numpy as np
from PIL import Image
from copy import deepcopy
from transformers import AutoTokenizer, Qwen2VLImageProcessor
from torchcodec.decoders import VideoDecoder
import binpacking
from tqdm import tqdm
import concurrent.futures
import time
def read_data(file_path):
"""Read JSON or JSONL file"""
if file_path.endswith(('.json', '.jsonl')):
with open(file_path, 'r') as f:
if file_path.endswith('.json'):
return json.load(f)
return [json.loads(line) for line in f]
raise ValueError('Please provide a .json or .jsonl file')
def write_data(file_path, data):
"""Write data to JSON or JSONL file"""
with open(file_path, 'w') as f:
if file_path.endswith('.json'):
json.dump(data, f, indent=4)
elif file_path.endswith('.jsonl'):
for item in data:
f.write(json.dumps(item) + '\n')
class DataArguments:
def __init__(self):
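# Pixel budgets, written as (token budget) * 28 * 28 on the assumption of the
# Qwen2-VL family's 28x28-pixel merged patch size; adjust for your model if needed.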
self.max_pixels = 2048 * 28 * 28
self.min_pixels = 32 * 28 * 28
self.video_max_frame_pixels = 576 * 28 * 28
self.video_min_frame_pixels = 144 * 28 * 28
self.base_interval = 4
self.video_min_frames = 4
self.video_max_frames = 8
self.data_path = ''
class MultimodalProcessor:
def __init__(self, data_args, base_processor, device='cpu'):
self.data_args = data_args
self.base_processor = base_processor
self.device = device
def _configure_processor(self, max_val, min_val):
processor = deepcopy(self.base_processor)
processor.max_pixels = max_val
processor.min_pixels = min_val
processor.size = {'longest_edge': max_val, 'shortest_edge': min_val}
return processor
def process_image(self, image_file):
image_path = os.path.join(self.data_args.data_path, image_file)
if not os.path.exists(image_path):
print(f'Image file does not exist: {image_path}')
return 0
processor = self._configure_processor(self.data_args.max_pixels, self.data_args.min_pixels)
image = Image.open(image_path).convert('RGB')
visual_processed = processor.preprocess(images=image, return_tensors='pt')
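# image_grid_thw is the (t, h, w) grid of image patches produced by the processor;
# its product divided by 4 (the 2x2 spatial merge) approximates the visual token count.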
return int(visual_processed['image_grid_thw'].prod()) // 4
def process_video(self, video_file):
video_path = os.path.join(self.data_args.data_path, video_file)
processor = self._configure_processor(self.data_args.video_max_frame_pixels, self.data_args.video_min_frame_pixels)
decoder = VideoDecoder(video_path, device=self.device)
total_frames = decoder.metadata.num_frames
avg_fps = decoder.metadata.average_fps
video_length = total_frames / avg_fps
interval = self.data_args.base_interval
num_frames_to_sample = round(video_length / interval)
target_frames = min(max(num_frames_to_sample, self.data_args.video_min_frames), self.data_args.video_max_frames)
frame_idx = np.unique(np.linspace(0, total_frames - 1, target_frames, dtype=int)).tolist()
frame_batch = decoder.get_frames_at(indices=frame_idx)
video_frames_numpy = frame_batch.data.cpu().numpy()
visual_processed = processor.preprocess(images=None, videos=video_frames_numpy, return_tensors='pt')
return int(visual_processed['video_grid_thw'].prod()) // 4
def calculate_tokens(conversation, processor, tokenizer):
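# Rough per-sample token estimate: text tokens from the chat template plus visual
# tokens from the image/video processors. The initial 21 is treated as a fixed
# per-sample overhead (e.g. special/template tokens); this constant is an assumption.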
total_tokens = 21
roles = {'human': 'user', 'gpt': 'assistant'}
for message in conversation['conversations']:
role = message['from']
text = message['value']
conv = [{'role': roles[role], 'content': text}]
encode_id = tokenizer.apply_chat_template(conv, return_tensors='pt', add_generation_prompt=False)[0]
total_tokens += len(encode_id)
if 'image' in conversation:
images = conversation['image'] if isinstance(conversation['image'], list) else [conversation['image']]
for image_file in images:
total_tokens += processor.process_image(image_file)
elif 'video' in conversation:
videos = conversation['video'] if isinstance(conversation['video'], list) else [conversation['video']]
for video_file in videos:
total_tokens += processor.process_video(video_file)
return total_tokens
def pack_data(data_list, pack_length):
# Extract the length of each data item
lengths = [data["num_tokens"] for data in data_list]
grouped_indices = binpacking.to_constant_volume(
list(enumerate(lengths)), # Explicitly convert to list
pack_length,
weight_pos=1
)
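# grouped_indices is a list of bins; each bin is a list of (original_index, num_tokens)
# tuples chosen so that the token counts within a bin sum to at most pack_length.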
packed_data = []
for group in grouped_indices:
group_data = []
for index, _ in group:
new_data = data_list[index].copy()
new_data.pop("num_tokens", None)
group_data.append(new_data)
packed_data.append(group_data)
return packed_data
datasets = {
'dummy_dataset': {
'data_path': '',
'annotation_path': 'path/to/your/annotation.json'
}
}
data_args = DataArguments()
model_path = 'path/to/your/model'
tokenizer = AutoTokenizer.from_pretrained(model_path)
tokenizer.chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
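# A minimal ChatML-style template (no default system prompt) so that apply_chat_template
# can be used here purely to count the text tokens contributed by each message.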
base_image_processor = Qwen2VLImageProcessor.from_pretrained(model_path)
print(f'Successfully loaded model components from {model_path}')
processor = MultimodalProcessor(data_args, base_image_processor, device='cpu')
for dataset_name, config in datasets.items():
processor.data_args.data_path = config['data_path']
annotation_path = os.path.join(processor.data_args.data_path, config['annotation_path'])
print(f'\n--- Processing dataset: {dataset_name} ---')
print(f'Annotation file path: {annotation_path}')
print(f'Image configuration: max_pixels={data_args.max_pixels}, min_pixels={data_args.min_pixels}')
print(f'Video frame configuration: video_max_frame_pixels={data_args.video_max_frame_pixels}, video_min_frame_pixels={data_args.video_min_frame_pixels}')
if not os.path.exists(annotation_path):
print(f'Annotation file not found: {annotation_path}')
continue
data = read_data(annotation_path)
count_file_path = os.path.splitext(annotation_path)[0] + '_count.json'
if os.path.exists(count_file_path):
print(f"Found pre - calculated token counts, loading data from {count_file_path}.")
data_with_tokens = read_data(count_file_path)
else:
def calculate_and_update(item):
item['num_tokens'] = calculate_tokens(item, processor, tokenizer)
return item
with concurrent.futures.ThreadPoolExecutor() as executor:
data_with_tokens = list(tqdm(executor.map(calculate_and_update, data), total=len(data), desc=f"Processing {dataset_name} data"))
# Save the token count results
write_data(count_file_path, data_with_tokens)
print(f"Token counts saved to: {count_file_path}")
# Assume the packing length is 4096
pack_length = 4096
# Define the batch size
batch_size = 256
all_packed_results = []
# Record the start time of binpacking
start_time = time.time()
for i in range(0, len(data_with_tokens), batch_size):
batch_data = data_with_tokens[i: i + batch_size]
batch_packed_result = pack_data(batch_data, pack_length)
all_packed_results.extend(batch_packed_result)
# Record the end time of binpacking
end_time = time.time()
# Calculate the time spent on binpacking
binpack_time = end_time - start_time
print(f"Time spent on binpacking: {binpack_time:.4f} seconds")
# Save the packed results as a JSON file
pack_output_path = os.path.splitext(annotation_path)[0] + '_pack.json'
with open(pack_output_path, 'w', encoding='utf-8') as file:
json.dump(all_packed_results, file, indent=2)
print(f"Packed results saved to: {pack_output_path}")
# qwen-vl-utils
Qwen-VL Utils contains a set of helper functions for processing and integrating visual language information with the Qwen-VL series models.
## Install
```bash
pip install qwen-vl-utils
```
## Usage
### Qwen2VL
```python
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
# You can directly insert a local file path, a URL, or a base64-encoded image at the position you want in the text.
messages = [
# Image
## Local file path
[{"role": "user", "content": [{"type": "image", "image": "file:///path/to/your/image.jpg"}, {"type": "text", "text": "Describe this image."}]}],
## Image URL
[{"role": "user", "content": [{"type": "image", "image": "http://path/to/your/image.jpg"}, {"type": "text", "text": "Describe this image."}]}],
## Base64 encoded image
[{"role": "user", "content": [{"type": "image", "image": "data:image;base64,/9j/..."}, {"type": "text", "text": "Describe this image."}]}],
## PIL.Image.Image
[{"role": "user", "content": [{"type": "image", "image": pil_image}, {"type": "text", "text": "Describe this image."}]}],
## The model dynamically adjusts the image size; specify dimensions if required.
[{"role": "user", "content": [{"type": "image", "image": "file:///path/to/your/image.jpg", "resized_height": 280, "resized_width": 420}, {"type": "text", "text": "Describe this image."}]}],
# Video
## Local video path
[{"role": "user", "content": [{"type": "video", "video": "file:///path/to/video1.mp4"}, {"type": "text", "text": "Describe this video."}]}],
## Local video frames
[{"role": "user", "content": [{"type": "video", "video": ["file:///path/to/extracted_frame1.jpg", "file:///path/to/extracted_frame2.jpg", "file:///path/to/extracted_frame3.jpg"],}, {"type": "text", "text": "Describe this video."},],}],
## The model dynamically adjusts the number of video frames, height, and width; specify these args if required.
[{"role": "user", "content": [{"type": "video", "video": "file:///path/to/video1.mp4", "fps": 2.0, "resized_height": 280, "resized_width": 280}, {"type": "text", "text": "Describe this video."}]}],
]
processor = AutoProcessor.from_pretrained(model_path)
model = Qwen2VLForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto", device_map="auto")
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
images, videos = process_vision_info(messages)
inputs = processor(text=text, images=images, videos=videos, padding=True, return_tensors="pt")
print(inputs)
generated_ids = model.generate(**inputs)
print(generated_ids)
```
### Qwen2.5VL
```python
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
# You can cap the total number of pixels for a video through the environment variable
# VIDEO_MAX_PIXELS, based on the maximum number of tokens the model can accept, e.g.:
# export VIDEO_MAX_PIXELS=$((32000 * 28 * 28 * 9 / 10))
# You can directly insert a local file path, a URL, or a base64-encoded image at the position you want in the text.
messages = [
# Image
## Local file path
[{"role": "user", "content": [{"type": "image", "image": "file:///path/to/your/image.jpg"}, {"type": "text", "text": "Describe this image."}]}],
## Image URL
[{"role": "user", "content": [{"type": "image", "image": "http://path/to/your/image.jpg"}, {"type": "text", "text": "Describe this image."}]}],
## Base64 encoded image
[{"role": "user", "content": [{"type": "image", "image": "data:image;base64,/9j/..."}, {"type": "text", "text": "Describe this image."}]}],
## PIL.Image.Image
[{"role": "user", "content": [{"type": "image", "image": pil_image}, {"type": "text", "text": "Describe this image."}]}],
## The model dynamically adjusts the image size; specify dimensions if required.
[{"role": "user", "content": [{"type": "image", "image": "file:///path/to/your/image.jpg", "resized_height": 280, "resized_width": 420}, {"type": "text", "text": "Describe this image."}]}],
# Video
## Local video path
[{"role": "user", "content": [{"type": "video", "video": "file:///path/to/video1.mp4"}, {"type": "text", "text": "Describe this video."}]}],
## Local video frames
[{"role": "user", "content": [{"type": "video", "video": ["file:///path/to/extracted_frame1.jpg", "file:///path/to/extracted_frame2.jpg", "file:///path/to/extracted_frame3.jpg"],}, {"type": "text", "text": "Describe this video."},],}],
## The model dynamically adjusts the number of video frames, height, and width; specify these args if required.
[{"role": "user", "content": [{"type": "video", "video": "file:///path/to/video1.mp4", "fps": 2.0, "resized_height": 280, "resized_width": 280}, {"type": "text", "text": "Describe this video."}]}],
]
processor = AutoProcessor.from_pretrained(model_path)
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto", device_map="auto")
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
images, videos, video_kwargs = process_vision_info(messages, return_video_kwargs=True)
inputs = processor(text=text, images=images, videos=videos, padding=True, return_tensors="pt", **video_kwargs)
print(inputs)
generated_ids = model.generate(**inputs)
print(generated_ids)
```
### Qwen3VL
```python
from transformers import Qwen3VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
# You can directly insert a local file path, a URL, or a base64-encoded image at the position you want in the text.
messages = [
# Image
## Local file path
[{"role": "user", "content": [{"type": "image", "image": "file:///path/to/your/image.jpg"}, {"type": "text", "text": "Describe this image."}]}],
## Image URL
[{"role": "user", "content": [{"type": "image", "image": "http://path/to/your/image.jpg"}, {"type": "text", "text": "Describe this image."}]}],
## Base64 encoded image
[{"role": "user", "content": [{"type": "image", "image": "data:image;base64,/9j/..."}, {"type": "text", "text": "Describe this image."}]}],
## PIL.Image.Image
[{"role": "user", "content": [{"type": "image", "image": pil_image}, {"type": "text", "text": "Describe this image."}]}],
## The model dynamically adjusts the image size; specify dimensions if required.
[{"role": "user", "content": [{"type": "image", "image": "file:///path/to/your/image.jpg", "resized_height": 280, "resized_width": 420}, {"type": "text", "text": "Describe this image."}]}],
# Video
## Local video path
[{"role": "user", "content": [{"type": "video", "video": "file:///path/to/video1.mp4"}, {"type": "text", "text": "Describe this video."}]}],
## Local video frames
[{"role": "user", "content": [{"type": "video", "video": ["file:///path/to/extracted_frame1.jpg", "file:///path/to/extracted_frame2.jpg", "file:///path/to/extracted_frame3.jpg"],}, {"type": "text", "text": "Describe this video."},],}],
## The model dynamically adjusts the number of video frames, height, and width; specify these args if required.
[{"role": "user", "content": [{"type": "video", "video": "file:///path/to/video1.mp4", "fps": 2.0, "resized_height": 280, "resized_width": 280}, {"type": "text", "text": "Describe this video."}]}],
]
processor = AutoProcessor.from_pretrained(model_path)
model = Qwen3VLForConditionalGeneration.from_pretrained(model_path, dtype="auto", device_map="auto")
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
images, videos, video_kwargs = process_vision_info(messages, image_patch_size=16, return_video_kwargs=True, return_video_metadata=True)
if videos is not None:
videos, video_metadatas = zip(*videos)
videos, video_metadatas = list(videos), list(video_metadatas)
else:
video_metadatas = None
inputs = processor(text=text, images=images, videos=videos, video_metadata=video_metadatas, return_tensors="pt", do_resize=False, **video_kwargs)
inputs = inputs.to(model.device)
generated_ids = model.generate(**inputs)
print(generated_ids)
```
[project]
name = "qwen-vl-utils"
version = "0.0.14"
description = "Qwen Vision Language Model Utils - PyTorch"
authors = [
{ name = "Qwen Team", email = "chenkeqin.ckq@alibaba-inc.com" },
]
dependencies = [
"requests",
"pillow",
"av",
"packaging",
]
readme = "README.md"
requires-python = ">= 3.8"
license = {text = "Apache-2.0"}
keywords = [
'large language model',
'vision language model',
'qwen-vl',
'pytorch',
]
classifiers = [
'Development Status :: 4 - Beta',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'Programming Language :: Python :: 3',
'License :: OSI Approved :: Apache Software License',
]
[project.urls]
Homepage = "https://github.com/QwenLM/Qwen2-VL/tree/main/qwen-vl-utils"
Repository = "https://github.com/QwenLM/Qwen2-VL.git"
Issues = "https://github.com/QwenLM/Qwen2-VL/issues"
[project.optional-dependencies]
decord = [
"decord",
]
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.rye]
managed = true
dev-dependencies = [
"torch",
"torchvision",
]
[tool.hatch.metadata]
allow-direct-references = true
[tool.hatch.build.targets.wheel]
packages = ["src/qwen_vl_utils"]
[tool.ruff]
line-length = 119
[tool.ruff.lint]
ignore = ["C408", "C901", "E501", "E731", "E741", "W605"]
select = ["C", "E", "F", "I", "W"]
[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["E402", "F401", "F403", "F811"]
[tool.ruff.lint.isort]
lines-after-imports = 2
known-first-party = ["qwen_vl_utils"]
[tool.ruff.format]
quote-style = "double"
indent-style = "space"
skip-magic-trailing-comma = false
line-ending = "auto"
# generated by rye
# use `rye lock` or `rye sync` to update this lockfile
#
# last locked with the following flags:
# pre: false
# features: ["decord"]
# all-features: false
# with-sources: false
# generate-hashes: false
# universal: false
-e file:.
av==12.3.0
# via qwen-vl-utils
certifi==2022.12.7
# via requests
charset-normalizer==2.1.1
# via requests
decord==0.6.0
# via qwen-vl-utils
filelock==3.13.1
# via torch
# via triton
fsspec==2024.2.0
# via torch
idna==3.4
# via requests
jinja2==3.1.3
# via torch
markupsafe==2.1.5
# via jinja2
mpmath==1.3.0
# via sympy
networkx==3.1
# via torch
numpy==1.24.1
# via decord
# via torchvision
nvidia-cublas-cu12==12.1.3.1
# via nvidia-cudnn-cu12
# via nvidia-cusolver-cu12
# via torch
nvidia-cuda-cupti-cu12==12.1.105
# via torch
nvidia-cuda-nvrtc-cu12==12.1.105
# via torch
nvidia-cuda-runtime-cu12==12.1.105
# via torch
nvidia-cudnn-cu12==9.1.0.70
# via torch
nvidia-cufft-cu12==11.0.2.54
# via torch
nvidia-curand-cu12==10.3.2.106
# via torch
nvidia-cusolver-cu12==11.4.5.107
# via torch
nvidia-cusparse-cu12==12.1.0.106
# via nvidia-cusolver-cu12
# via torch
nvidia-nccl-cu12==2.20.5
# via torch
nvidia-nvjitlink-cu12==12.6.68
# via nvidia-cusolver-cu12
# via nvidia-cusparse-cu12
nvidia-nvtx-cu12==12.1.105
# via torch
packaging==24.1
# via qwen-vl-utils
pillow==10.2.0
# via qwen-vl-utils
# via torchvision
requests==2.28.1
# via qwen-vl-utils
sympy==1.12
# via torch
torch==2.4.0
# via torchvision
torchvision==0.19.0
triton==3.0.0
# via torch
typing-extensions==4.9.0
# via torch
urllib3==1.26.13
# via requests
# generated by rye
# use `rye lock` or `rye sync` to update this lockfile
#
# last locked with the following flags:
# pre: false
# features: ["decord"]
# all-features: false
# with-sources: false
# generate-hashes: false
# universal: false
-e file:.
av==12.3.0
# via qwen-vl-utils
certifi==2022.12.7
# via requests
charset-normalizer==2.1.1
# via requests
decord==0.6.0
# via qwen-vl-utils
idna==3.4
# via requests
numpy==1.24.4
# via decord
packaging==24.1
# via qwen-vl-utils
pillow==10.2.0
# via qwen-vl-utils
requests==2.28.1
# via qwen-vl-utils
urllib3==1.26.13
# via requests
from .vision_process import (
extract_vision_info,
fetch_image,
fetch_video,
process_vision_info,
smart_resize,
)
import base64
import copy
import logging
import math
import os
import sys
import time
import warnings
from functools import lru_cache
from io import BytesIO
from typing import Optional, Union, Tuple, List, Any, Dict
from concurrent.futures import ThreadPoolExecutor
import requests
import torch
import torchvision
from packaging import version
from PIL import Image
import numpy as np
from torchvision import io, transforms
from torchvision.transforms import InterpolationMode
MAX_RATIO = 200
SPATIAL_MERGE_SIZE = 2
IMAGE_MIN_TOKEN_NUM = 4
IMAGE_MAX_TOKEN_NUM = 16384
VIDEO_MIN_TOKEN_NUM = 128
VIDEO_MAX_TOKEN_NUM = 768
FPS = 2.0
FRAME_FACTOR = 2
FPS_MIN_FRAMES = 4
FPS_MAX_FRAMES = 768
MAX_NUM_WORKERS_FETCH_VIDEO = 8
MODEL_SEQ_LEN = int(float(os.environ.get('MODEL_SEQ_LEN', 128000)))
logger = logging.getLogger(__name__)
def round_by_factor(number: int, factor: int) -> int:
"""Returns the closest integer to 'number' that is divisible by 'factor'."""
return round(number / factor) * factor
def ceil_by_factor(number: int, factor: int) -> int:
"""Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
return math.ceil(number / factor) * factor
def floor_by_factor(number: int, factor: int) -> int:
"""Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
return math.floor(number / factor) * factor
def smart_resize(height: int, width: int, factor: int, min_pixels: Optional[int] = None, max_pixels: Optional[int] = None) -> Tuple[int, int]:
"""
Rescales the image so that the following conditions are met:
1. Both dimensions (height and width) are divisible by 'factor'.
2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
3. The aspect ratio of the image is maintained as closely as possible.
"""
max_pixels = max_pixels if max_pixels is not None else (IMAGE_MAX_TOKEN_NUM * factor ** 2)
min_pixels = min_pixels if min_pixels is not None else (IMAGE_MIN_TOKEN_NUM * factor ** 2)
assert max_pixels >= min_pixels, "The max_pixels of image must be greater than or equal to min_pixels."
if max(height, width) / min(height, width) > MAX_RATIO:
raise ValueError(
f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
)
h_bar = max(factor, round_by_factor(height, factor))
w_bar = max(factor, round_by_factor(width, factor))
if h_bar * w_bar > max_pixels:
beta = math.sqrt((height * width) / max_pixels)
h_bar = floor_by_factor(height / beta, factor)
w_bar = floor_by_factor(width / beta, factor)
elif h_bar * w_bar < min_pixels:
beta = math.sqrt(min_pixels / (height * width))
h_bar = ceil_by_factor(height * beta, factor)
w_bar = ceil_by_factor(width * beta, factor)
return h_bar, w_bar
def to_rgb(pil_image: Image.Image) -> Image.Image:
if pil_image.mode == 'RGBA':
white_background = Image.new("RGB", pil_image.size, (255, 255, 255))
white_background.paste(pil_image, mask=pil_image.split()[3]) # Use alpha channel as mask
return white_background
else:
return pil_image.convert("RGB")
def fetch_image(ele: Dict[str, Union[str, Image.Image]], image_patch_size: int = 14) -> Image.Image:
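# Accepts a PIL.Image, an http(s) URL, a "file://" path, a base64 data URI, or a plain
# local path; converts the image to RGB and resizes it so both sides are multiples of
# image_patch_size * SPATIAL_MERGE_SIZE while respecting the min/max pixel budget.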
if "image" in ele:
image = ele["image"]
else:
image = ele["image_url"]
image_obj = None
patch_factor = int(image_patch_size * SPATIAL_MERGE_SIZE)
if isinstance(image, Image.Image):
image_obj = image
elif image.startswith("http://") or image.startswith("https://"):
with requests.get(image, stream=True) as response:
response.raise_for_status()
with BytesIO(response.content) as bio:
image_obj = copy.deepcopy(Image.open(bio))
elif image.startswith("file://"):
image_obj = Image.open(image[7:])
elif image.startswith("data:image"):
if "base64," in image:
_, base64_data = image.split("base64,", 1)
data = base64.b64decode(base64_data)
with BytesIO(data) as bio:
image_obj = copy.deepcopy(Image.open(bio))
else:
image_obj = Image.open(image)
if image_obj is None:
raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}")
image = to_rgb(image_obj)
## resize
if "resized_height" in ele and "resized_width" in ele:
resized_height, resized_width = smart_resize(
ele["resized_height"],
ele["resized_width"],
factor=patch_factor,
)
else:
width, height = image.size
min_pixels = ele.get("min_pixels", IMAGE_MIN_TOKEN_NUM * patch_factor ** 2)
max_pixels = ele.get("max_pixels", IMAGE_MAX_TOKEN_NUM * patch_factor ** 2)
resized_height, resized_width = smart_resize(
height,
width,
factor=patch_factor,
min_pixels=min_pixels,
max_pixels=max_pixels,
)
image = image.resize((resized_width, resized_height))
return image
def smart_nframes(
ele: Dict[str, Any],
total_frames: int,
video_fps: Union[int, float],
) -> int:
"""calculate the number of frames for video used for model inputs.
Args:
ele (dict): a dict contains the configuration of video.
support either `fps` or `nframes`:
- nframes: the number of frames to extract for model inputs.
- fps: the fps to extract frames for model inputs.
- min_frames: the minimum number of frames of the video, only used when fps is provided.
- max_frames: the maximum number of frames of the video, only used when fps is provided.
total_frames (int): the original total number of frames of the video.
video_fps (int | float): the original fps of the video.
Raises:
ValueError: nframes should be in the interval [FRAME_FACTOR, total_frames].
Returns:
int: the number of frames for video used for model inputs.
"""
assert not ("fps" in ele and "nframes" in ele), "Only accept either `fps` or `nframes`"
if "nframes" in ele:
nframes = round_by_factor(ele["nframes"], FRAME_FACTOR)
else:
fps = ele.get("fps", FPS)
min_frames = ceil_by_factor(ele.get("min_frames", FPS_MIN_FRAMES), FRAME_FACTOR)
max_frames = floor_by_factor(ele.get("max_frames", min(FPS_MAX_FRAMES, total_frames)), FRAME_FACTOR)
nframes = total_frames / video_fps * fps
if nframes > total_frames:
logger.warning(f"smart_nframes: nframes[{nframes}] > total_frames[{total_frames}]")
nframes = min(min(max(nframes, min_frames), max_frames), total_frames)
nframes = floor_by_factor(nframes, FRAME_FACTOR)
if not (FRAME_FACTOR <= nframes <= total_frames):
raise ValueError(f"nframes should be in the interval [{FRAME_FACTOR}, {total_frames}], but got {nframes}.")
return nframes
def _read_video_torchvision(
ele: Dict[str, Any],
) -> Tuple[torch.Tensor, Dict[str, Any], float]:
"""read video using torchvision.io.read_video
Args:
ele (dict): a dict contains the configuration of video.
support keys:
- video: the path of video. support "file://", "http://", "https://" and local path.
- video_start: the start time of video.
- video_end: the end time of video.
Returns:
tuple: (video tensor with shape (T, C, H, W), video metadata dict, sample fps).
"""
video_path = ele["video"]
if version.parse(torchvision.__version__) < version.parse("0.19.0"):
if "http://" in video_path or "https://" in video_path:
warnings.warn("torchvision < 0.19.0 does not support http/https video path, please upgrade to 0.19.0.")
if "file://" in video_path:
video_path = video_path[7:]
st = time.time()
video, audio, info = io.read_video(
video_path,
start_pts=ele.get("video_start", 0.0),
end_pts=ele.get("video_end", None),
pts_unit="sec",
output_format="TCHW",
)
total_frames, video_fps = video.size(0), info["video_fps"]
logger.info(f"torchvision: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s")
nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
idx = torch.linspace(0, total_frames - 1, nframes).round().long()
sample_fps = nframes / max(total_frames, 1e-6) * video_fps
video = video[idx]
video_metadata = dict(
fps=video_fps,
frames_indices=idx,
total_num_frames=total_frames,
video_backend="torchvision",
)
return video, video_metadata, sample_fps
def is_decord_available() -> bool:
import importlib.util
return importlib.util.find_spec("decord") is not None
def calculate_video_frame_range(
ele: Dict[str, Any],
total_frames: int,
video_fps: float,
) -> Tuple[int, int, int]:
"""
Calculate the start and end frame indices based on the given time range.
Args:
ele (dict): A dictionary containing optional 'video_start' and 'video_end' keys (in seconds).
total_frames (int): Total number of frames in the video.
video_fps (float): Frames per second of the video.
Returns:
tuple: A tuple containing (start_frame, end_frame, frame_count).
Raises:
ValueError: If input parameters are invalid or the time range is inconsistent.
"""
# Validate essential parameters
if video_fps <= 0:
raise ValueError("video_fps must be a positive number")
if total_frames <= 0:
raise ValueError("total_frames must be a positive integer")
# Get start and end time in seconds
video_start = ele.get("video_start", None)
video_end = ele.get("video_end", None)
if video_start is None and video_end is None:
return 0, total_frames - 1, total_frames
max_duration = total_frames / video_fps
# Process start frame
if video_start is not None:
video_start_clamped = max(0.0, min(video_start, max_duration))
start_frame = math.ceil(video_start_clamped * video_fps)
else:
start_frame = 0
# Process end frame
if video_end is not None:
video_end_clamped = max(0.0, min(video_end, max_duration))
end_frame = math.floor(video_end_clamped * video_fps)
end_frame = min(end_frame, total_frames - 1)
else:
end_frame = total_frames - 1
# Validate frame order
if start_frame >= end_frame:
raise ValueError(
f"Invalid time range: Start frame {start_frame} (at {video_start_clamped if video_start is not None else 0}s) "
f"exceeds end frame {end_frame} (at {video_end_clamped if video_end is not None else max_duration}s). "
f"Video duration: {max_duration:.2f}s ({total_frames} frames @ {video_fps}fps)"
)
logger.info(f"calculate video frame range: {start_frame=}, {end_frame=}, {total_frames=} from {video_start=}, {video_end=}, {video_fps=:.3f}")
return start_frame, end_frame, end_frame - start_frame + 1
def _read_video_decord(
ele: Dict[str, Any],
) -> Tuple[torch.Tensor, Dict[str, Any], float]:
"""read video using decord.VideoReader
Args:
ele (dict): a dict contains the configuration of video.
support keys:
- video: the path of video. support "file://", "http://", "https://" and local path.
- video_start: the start time of video.
- video_end: the end time of video.
Returns:
tuple: (video tensor with shape (T, C, H, W), video metadata dict, sample fps).
"""
import decord
video_path = ele["video"]
st = time.time()
vr = decord.VideoReader(video_path)
total_frames, video_fps = len(vr), vr.get_avg_fps()
start_frame, end_frame, total_frames = calculate_video_frame_range(
ele,
total_frames,
video_fps,
)
nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
idx = torch.linspace(start_frame, end_frame, nframes).round().long().tolist()
video = vr.get_batch(idx).asnumpy()
video = torch.tensor(video).permute(0, 3, 1, 2) # Convert to TCHW format
logger.info(f"decord: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s")
sample_fps = nframes / max(total_frames, 1e-6) * video_fps
video_metadata = dict(
fps=video_fps,
frames_indices=idx,
total_num_frames=total_frames,
video_backend="decord",
)
return video, video_metadata, sample_fps
def is_torchcodec_available() -> bool:
import importlib.util
return importlib.util.find_spec("torchcodec") is not None
def _read_video_torchcodec(
ele: Dict[str, Any],
) -> Tuple[torch.Tensor, Dict[str, Any], float]:
"""read video using torchcodec.decoders.VideoDecoder
Args:
ele (dict): a dict contains the configuration of video.
support keys:
- video: the path of video. support "file://", "http://", "https://" and local path.
- video_start: the start time of video.
- video_end: the end time of video.
Returns:
tuple: (video tensor with shape (T, C, H, W), video metadata dict, sample fps).
"""
from torchcodec.decoders import VideoDecoder
TORCHCODEC_NUM_THREADS = int(os.environ.get('TORCHCODEC_NUM_THREADS', 8))
logger.info(f"set TORCHCODEC_NUM_THREADS: {TORCHCODEC_NUM_THREADS}")
video_path = ele["video"]
st = time.time()
decoder = VideoDecoder(video_path, num_ffmpeg_threads=TORCHCODEC_NUM_THREADS)
video_fps = decoder.metadata.average_fps
total_frames = decoder.metadata.num_frames
start_frame, end_frame, total_frames = calculate_video_frame_range(
ele,
total_frames,
video_fps,
)
nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
idx = torch.linspace(start_frame, end_frame, nframes).round().long().tolist()
sample_fps = nframes / max(total_frames, 1e-6) * video_fps
video = decoder.get_frames_at(indices=idx).data
logger.info(f"torchcodec: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s")
video_metadata = dict(
fps=video_fps,
frames_indices=idx,
total_num_frames=total_frames,
video_backend="torchcodec",
)
return video, video_metadata, sample_fps
VIDEO_READER_BACKENDS = {
"decord": _read_video_decord,
"torchvision": _read_video_torchvision,
"torchcodec": _read_video_torchcodec,
}
FORCE_QWENVL_VIDEO_READER = os.getenv("FORCE_QWENVL_VIDEO_READER", None)
@lru_cache(maxsize=1)
def get_video_reader_backend() -> str:
if FORCE_QWENVL_VIDEO_READER is not None:
video_reader_backend = FORCE_QWENVL_VIDEO_READER
elif is_torchcodec_available():
video_reader_backend = "torchcodec"
elif is_decord_available():
video_reader_backend = "decord"
else:
video_reader_backend = "torchvision"
print(f"qwen-vl-utils using {video_reader_backend} to read video.", file=sys.stderr)
return video_reader_backend
def fetch_video(ele: Dict[str, Any], image_patch_size: int = 14, return_video_sample_fps: bool = False,
return_video_metadata: bool = False) -> Union[torch.Tensor, List[Image.Image]]:
image_factor = image_patch_size * SPATIAL_MERGE_SIZE
VIDEO_FRAME_MIN_PIXELS = VIDEO_MIN_TOKEN_NUM * image_factor * image_factor
VIDEO_FRAME_MAX_PIXELS = VIDEO_MAX_TOKEN_NUM * image_factor * image_factor
if isinstance(ele["video"], str):
video_reader_backend = get_video_reader_backend()
try:
video, video_metadata, sample_fps = VIDEO_READER_BACKENDS[video_reader_backend](ele)
except Exception as e:
logger.warning(f"video_reader_backend {video_reader_backend} error, use torchvision as default, msg: {e}")
video, video_metadata, sample_fps = VIDEO_READER_BACKENDS["torchvision"](ele)
else:
# The input is a list of frames
assert isinstance(ele["video"], (list, tuple))
process_info = ele.copy()
process_info.pop("type", None)
process_info.pop("video", None)
# use ThreadPoolExecutor to parallel process frames
max_workers = min(MAX_NUM_WORKERS_FETCH_VIDEO, len(ele["video"]))
with ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = [
executor.submit(fetch_image, {"image": video_element, **process_info}, image_factor)
for video_element in ele["video"]
]
image_list = [future.result() for future in futures]
nframes = ceil_by_factor(len(image_list), FRAME_FACTOR)
if len(image_list) < nframes:
image_list.extend([image_list[-1]] * (nframes - len(image_list)))
sample_fps = ele.get("sample_fps", 2.0)
video = torch.stack([
torch.from_numpy(np.array(image).transpose(2, 0, 1))
for image in image_list
])
# fake video metadata
raw_fps = process_info.pop("raw_fps", sample_fps)
video_metadata = dict(
fps=raw_fps,
frames_indices=[i for i in range(len(video))],
total_num_frames=(nframes / sample_fps) * raw_fps,
)
nframes, _, height, width = video.shape
min_pixels = ele.get("min_pixels", VIDEO_FRAME_MIN_PIXELS)
total_pixels = ele.get("total_pixels", MODEL_SEQ_LEN * image_factor * image_factor * 0.9)
max_pixels = max(min(VIDEO_FRAME_MAX_PIXELS, total_pixels / nframes * FRAME_FACTOR), int(min_pixels * 1.05))
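# Per-frame pixel cap: bounded above by VIDEO_FRAME_MAX_PIXELS and by the total pixel
# budget spread across the sampled frames, and bounded below by slightly more than min_pixels.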
max_pixels_supposed = ele.get("max_pixels", max_pixels)
if max_pixels_supposed > max_pixels:
logger.warning(f"The given max_pixels[{max_pixels_supposed}] exceeds limit[{max_pixels}].")
max_pixels = min(max_pixels_supposed, max_pixels)
if "resized_height" in ele and "resized_width" in ele:
resized_height, resized_width = smart_resize(
ele["resized_height"],
ele["resized_width"],
factor=image_factor,
)
else:
resized_height, resized_width = smart_resize(
height,
width,
factor=image_factor,
min_pixels=min_pixels,
max_pixels=max_pixels,
)
video = transforms.functional.resize(
video,
[resized_height, resized_width],
interpolation=InterpolationMode.BICUBIC,
antialias=True,
).float()
final_video = (video, video_metadata) if return_video_metadata else video
if return_video_sample_fps:
return final_video, sample_fps
return final_video
def extract_vision_info(conversations: Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]) -> List[Dict[str, Any]]:
vision_infos = []
if isinstance(conversations[0], dict):
conversations = [conversations]
for conversation in conversations:
for message in conversation:
if isinstance(message["content"], list):
for ele in message["content"]:
if (
"image" in ele
or "image_url" in ele
or "video" in ele
or ele.get("type", "text") in ("image", "image_url", "video")
):
vision_infos.append(ele)
return vision_infos
def process_vision_info(
conversations: Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]],
return_video_kwargs: bool = False,
return_video_metadata: bool = False,
image_patch_size: int = 14,
) -> Tuple[Optional[List[Image.Image]], Optional[List[Union[torch.Tensor, List[Image.Image]]]], Optional[Dict[str, Any]]]:
vision_infos = extract_vision_info(conversations)
## Read images or videos
image_inputs = []
video_inputs = []
video_sample_fps_list = []
for vision_info in vision_infos:
if "image" in vision_info or "image_url" in vision_info:
image_inputs.append(fetch_image(vision_info, image_patch_size=image_patch_size))
elif "video" in vision_info:
video_input, video_sample_fps = fetch_video(vision_info, return_video_sample_fps=True,
image_patch_size=image_patch_size, return_video_metadata=return_video_metadata)
video_sample_fps_list.append(video_sample_fps)
video_inputs.append(video_input)
else:
raise ValueError("image, image_url or video should in content.")
if len(image_inputs) == 0:
image_inputs = None
if len(video_inputs) == 0:
video_inputs = None
video_kwargs = {'do_sample_frames': False}
if not return_video_metadata: # BC for qwen2.5vl
video_kwargs.update({'fps': video_sample_fps_list})
if return_video_kwargs:
return image_inputs, video_inputs, video_kwargs
return image_inputs, video_inputs
from transformers import Qwen3VLForConditionalGeneration, AutoProcessor
# default: Load the model on the available device(s)
model = Qwen3VLForConditionalGeneration.from_pretrained(
"Qwen/Qwen3-VL-8B-Instruct", dtype="auto", device_map="auto"
)
# We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
# model = Qwen3VLForConditionalGeneration.from_pretrained(
# "Qwen/Qwen3-VL-8B-Instruct",
# dtype=torch.bfloat16,
# attn_implementation="flash_attention_2",
# device_map="auto",
# )
processor = AutoProcessor.from_pretrained("Qwen/Qwen3-VL-8B-Instruct")
messages = [
{
"role": "user",
"content": [
{
"type": "image",
"image": "./doc/demo.jpeg",
},
{"type": "text", "text": "Describe this image."},
],
}
]
# Preparation for inference
inputs = processor.apply_chat_template(
messages,
tokenize=True,
add_generation_prompt=True,
return_dict=True,
return_tensors="pt"
)
inputs = inputs.to(model.device)
# Inference: Generation of the output
generated_ids = model.generate(**inputs, max_new_tokens=128)
generated_ids_trimmed = [
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = processor.batch_decode(
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print(output_text)
from transformers import Qwen3VLForConditionalGeneration, AutoProcessor
# default: Load the model on the available device(s)
model = Qwen3VLForConditionalGeneration.from_pretrained(
"Qwen/Qwen3-VL-8B-Instruct", dtype="auto", device_map="auto"
)
# We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
# model = Qwen3VLForConditionalGeneration.from_pretrained(
# "Qwen/Qwen3-VL-8B-Instruct",
# dtype=torch.bfloat16,
# attn_implementation="flash_attention_2",
# device_map="auto",
# )
processor = AutoProcessor.from_pretrained("Qwen/Qwen3-VL-8B-Instruct")
# Messages containing multiple images and a text query
messages = [
{
"role": "user",
"content": [
{"type": "image", "image": "./doc/demo.jpeg"},
{"type": "image", "image": "./doc/dog.jpg"},
{"type": "text", "text": "Identify the similarities between these images."},
],
}
]
# Preparation for inference
inputs = processor.apply_chat_template(
messages,
tokenize=True,
add_generation_prompt=True,
return_dict=True,
return_tensors="pt"
)
inputs = inputs.to(model.device)
# Inference: Generation of the output
generated_ids = model.generate(**inputs, max_new_tokens=128)
generated_ids_trimmed = [
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = processor.batch_decode(
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print(output_text)
from transformers import Qwen3VLForConditionalGeneration, AutoProcessor
import torch
# default: Load the model on the available device(s)
model = Qwen3VLForConditionalGeneration.from_pretrained(
"Qwen/Qwen3-VL-8B-Instruct", torch_dtype=torch.bfloat16, device_map="auto",attn_implementation="flash_attention_2"
)
# We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
# model = Qwen3VLForConditionalGeneration.from_pretrained(
# "Qwen/Qwen3-VL-8B-Instruct",
# dtype=torch.bfloat16,
# attn_implementation="flash_attention_2",
# device_map="auto",
# )
processor = AutoProcessor.from_pretrained("Qwen/Qwen3-VL-8B-Instruct")
messages = [
{
"role": "user",
"content": [
{
"type": "video",
"video": "./doc/space_woaudio.mp4",
"max_pixels": 300 * 300,
"fps": 1.0,
},
{"type": "text", "text": "Describe this video."},
],
}
]
# Preparation for inference
inputs = processor.apply_chat_template(
messages,
tokenize=True,
add_generation_prompt=True,
return_dict=True,
return_tensors="pt"
)
inputs = inputs.to(model.device)
# Inference: Generation of the output
generated_ids = model.generate(**inputs, max_new_tokens=128)
generated_ids_trimmed = [
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = processor.batch_decode(
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print(output_text)
# Core dependencies
gradio==5.49.1
gradio_client==1.13.3
transformers-stream-generator==0.0.5
transformers==4.57.1
#torch
#torchvision
accelerate
av==16.0.1
# Optional dependency
# Uncomment the following line if you need flash-attn
# flash-attn
# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import copy
import re
from argparse import ArgumentParser
from threading import Thread
import gradio as gr
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText, TextIteratorStreamer
try:
from vllm import SamplingParams, LLM
from qwen_vl_utils import process_vision_info
VLLM_AVAILABLE = True
except ImportError:
VLLM_AVAILABLE = False
print("Warning: vLLM not available. Install vllm and qwen-vl-utils to use vLLM backend.")
def _get_args():
parser = ArgumentParser()
parser.add_argument('-c',
'--checkpoint-path',
type=str,
default='Qwen/Qwen3-VL-235B-A22B-Instruct',
help='Checkpoint name or path, defaults to %(default)r')
parser.add_argument('--cpu-only', action='store_true', help='Run demo with CPU only')
parser.add_argument('--flash-attn2',
action='store_true',
default=False,
help='Enable flash_attention_2 when loading the model.')
parser.add_argument('--share',
action='store_true',
default=False,
help='Create a publicly shareable link for the interface.')
parser.add_argument('--inbrowser',
action='store_true',
default=False,
help='Automatically launch the interface in a new tab on the default browser.')
parser.add_argument('--server-port', type=int, default=7860, help='Demo server port.')
parser.add_argument('--server-name', type=str, default='127.0.0.1', help='Demo server name.')
parser.add_argument('--backend',
type=str,
choices=['hf', 'vllm'],
default='vllm',
help='Backend to use: hf (HuggingFace) or vllm (vLLM)')
parser.add_argument('--gpu-memory-utilization',
type=float,
default=0.70,
help='GPU memory utilization for vLLM (default: 0.70)')
parser.add_argument('--tensor-parallel-size',
type=int,
default=None,
help='Tensor parallel size for vLLM (default: auto)')
args = parser.parse_args()
return args
def _load_model_processor(args):
if args.backend == 'vllm':
if not VLLM_AVAILABLE:
raise ImportError("vLLM is not available. Please install vllm and qwen-vl-utils.")
os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
tensor_parallel_size = args.tensor_parallel_size
if tensor_parallel_size is None:
tensor_parallel_size = torch.cuda.device_count()
# Initialize vLLM sync engine
model = LLM(
model=args.checkpoint_path,
trust_remote_code=True,
gpu_memory_utilization=args.gpu_memory_utilization,
enforce_eager=False,
tensor_parallel_size=tensor_parallel_size,
seed=0
)
# Load processor for vLLM
processor = AutoProcessor.from_pretrained(args.checkpoint_path)
return model, processor, 'vllm'
else:
if args.cpu_only:
device_map = 'cpu'
else:
device_map = 'auto'
# Check if flash-attn2 flag is enabled and load model accordingly
if args.flash_attn2:
model = AutoModelForImageTextToText.from_pretrained(args.checkpoint_path,
torch_dtype='auto',
attn_implementation='flash_attention_2',
device_map=device_map)
else:
model = AutoModelForImageTextToText.from_pretrained(args.checkpoint_path, device_map=device_map)
processor = AutoProcessor.from_pretrained(args.checkpoint_path)
return model, processor, 'hf'
def _parse_text(text):
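# Convert model output to HTML for the Gradio chatbot: fenced ``` markers become
# <pre><code> blocks, characters inside code blocks are HTML-escaped so they render
# literally, and lines after the first are prefixed with <br>.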
lines = text.split('\n')
lines = [line for line in lines if line != '']
count = 0
for i, line in enumerate(lines):
if '```' in line:
count += 1
items = line.split('`')
if count % 2 == 1:
lines[i] = f'<pre><code class="language-{items[-1]}">'
else:
lines[i] = '<br></code></pre>'
else:
if i > 0:
if count % 2 == 1:
line = line.replace('`', r'\`')
line = line.replace('<', '&lt;')
line = line.replace('>', '&gt;')
line = line.replace(' ', '&nbsp;')
line = line.replace('*', '&ast;')
line = line.replace('_', '&lowbar;')
line = line.replace('-', '&#45;')
line = line.replace('.', '&#46;')
line = line.replace('!', '&#33;')
line = line.replace('(', '&#40;')
line = line.replace(')', '&#41;')
line = line.replace('$', '&#36;')
lines[i] = '<br>' + line
text = ''.join(lines)
return text
def _remove_image_special(text):
text = text.replace('<ref>', '').replace('</ref>', '')
return re.sub(r'<box>.*?(</box>|$)', '', text)
def _is_video_file(filename):
video_extensions = ['.mp4', '.avi', '.mkv', '.mov', '.wmv', '.flv', '.webm', '.mpeg']
return any(filename.lower().endswith(ext) for ext in video_extensions)
def _gc():
import gc
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
def _transform_messages(original_messages):
transformed_messages = []
for message in original_messages:
new_content = []
for item in message['content']:
if 'image' in item:
new_item = {'type': 'image', 'image': item['image']}
elif 'text' in item:
new_item = {'type': 'text', 'text': item['text']}
elif 'video' in item:
new_item = {'type': 'video', 'video': item['video']}
else:
continue
new_content.append(new_item)
new_message = {'role': message['role'], 'content': new_content}
transformed_messages.append(new_message)
return transformed_messages
def _prepare_inputs_for_vllm(messages, processor):
"""Prepare inputs for vLLM inference"""
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
image_inputs, video_inputs, video_kwargs = process_vision_info(
messages,
image_patch_size=processor.image_processor.patch_size,
return_video_kwargs=True,
return_video_metadata=True
)
mm_data = {}
if image_inputs is not None:
mm_data['image'] = image_inputs
if video_inputs is not None:
mm_data['video'] = video_inputs
return {
'prompt': text,
'multi_modal_data': mm_data,
'mm_processor_kwargs': video_kwargs
}
def _launch_demo(args, model, processor, backend):
def call_local_model(model, processor, messages, backend):
messages = _transform_messages(messages)
if backend == 'vllm':
# vLLM inference
inputs = _prepare_inputs_for_vllm(messages, processor)
sampling_params = SamplingParams(max_tokens=1024)
accumulated_text = ''
for output in model.generate(inputs, sampling_params=sampling_params):
for completion in output.outputs:
new_text = completion.text
if new_text:
accumulated_text += new_text
yield accumulated_text
else:
# HuggingFace inference
inputs = processor.apply_chat_template(
messages,
tokenize=True,
add_generation_prompt=True,
return_dict=True,
return_tensors="pt"
)
tokenizer = processor.tokenizer
streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
inputs = {k: v.to(model.device) for k, v in inputs.items()}
gen_kwargs = {'max_new_tokens': 1024, 'streamer': streamer, **inputs}
thread = Thread(target=model.generate, kwargs=gen_kwargs)
thread.start()
generated_text = ''
for new_text in streamer:
generated_text += new_text
yield generated_text
def create_predict_fn():
def predict(_chatbot, task_history):
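# Rebuild the multimodal message list from task_history (uploaded files are stored as
# paths and mapped to image/video entries), stream the model's reply, and update the
# chatbot display incrementally.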
nonlocal model, processor, backend
chat_query = _chatbot[-1][0]
query = task_history[-1][0]
if len(chat_query) == 0:
_chatbot.pop()
task_history.pop()
return _chatbot
print('User: ' + _parse_text(query))
history_cp = copy.deepcopy(task_history)
full_response = ''
messages = []
content = []
for q, a in history_cp:
if isinstance(q, (tuple, list)):
if _is_video_file(q[0]):
content.append({'video': f'{os.path.abspath(q[0])}'})
else:
content.append({'image': f'{os.path.abspath(q[0])}'})
else:
content.append({'text': q})
messages.append({'role': 'user', 'content': content})
messages.append({'role': 'assistant', 'content': [{'text': a}]})
content = []
messages.pop()
for response in call_local_model(model, processor, messages, backend):
_chatbot[-1] = (_parse_text(chat_query), _remove_image_special(_parse_text(response)))
yield _chatbot
full_response = _parse_text(response)
task_history[-1] = (query, full_response)
print('Qwen-VL-Chat: ' + _parse_text(full_response))
yield _chatbot
return predict
def create_regenerate_fn():
def regenerate(_chatbot, task_history):
nonlocal model, processor, backend
if not task_history:
return _chatbot
item = task_history[-1]
if item[1] is None:
return _chatbot
task_history[-1] = (item[0], None)
chatbot_item = _chatbot.pop(-1)
if chatbot_item[0] is None:
_chatbot[-1] = (_chatbot[-1][0], None)
else:
_chatbot.append((chatbot_item[0], None))
_chatbot_gen = predict(_chatbot, task_history)
for _chatbot in _chatbot_gen:
yield _chatbot
return regenerate
predict = create_predict_fn()
regenerate = create_regenerate_fn()
def add_text(history, task_history, text):
task_text = text
history = history if history is not None else []
task_history = task_history if task_history is not None else []
history = history + [(_parse_text(text), None)]
task_history = task_history + [(task_text, None)]
return history, task_history, ''
def add_file(history, task_history, file):
history = history if history is not None else []
task_history = task_history if task_history is not None else []
history = history + [((file.name,), None)]
task_history = task_history + [((file.name,), None)]
return history, task_history
def reset_user_input():
return gr.update(value='')
def reset_state(_chatbot, task_history):
task_history.clear()
_chatbot.clear()
_gc()
return []
with gr.Blocks() as demo:
gr.Markdown("""\
<p align="center"><img src="https://qianwen-res.oss-accelerate.aliyuncs.com/Qwen3-VL/qwen3vllogo.png" style="height: 80px"/><p>"""
)
gr.Markdown("""<center><font size=8>Qwen3-VL</center>""")
gr.Markdown(f"""\
<center><font size=3>This WebUI is based on Qwen3-VL, developed by Alibaba Cloud. Backend: {backend.upper()}</center>""")
gr.Markdown(f"""<center><font size=3>本 WebUI 基于 Qwen3-VL。</center>""")
chatbot = gr.Chatbot(label='Qwen3-VL', elem_classes='control-height', height=500)
query = gr.Textbox(lines=2, label='Input')
task_history = gr.State([])
with gr.Row():
addfile_btn = gr.UploadButton('📁 Upload (上传文件)', file_types=['image', 'video'])
submit_btn = gr.Button('🚀 Submit (发送)')
regen_btn = gr.Button('🤔️ Regenerate (重试)')
empty_bin = gr.Button('🧹 Clear History (清除历史)')
submit_btn.click(add_text, [chatbot, task_history, query],
[chatbot, task_history]).then(predict, [chatbot, task_history], [chatbot], show_progress=True)
submit_btn.click(reset_user_input, [], [query])
empty_bin.click(reset_state, [chatbot, task_history], [chatbot], show_progress=True)
regen_btn.click(regenerate, [chatbot, task_history], [chatbot], show_progress=True)
addfile_btn.upload(add_file, [chatbot, task_history, addfile_btn], [chatbot, task_history], show_progress=True)
gr.Markdown("""\
<font size=2>Note: This demo is governed by the original license of Qwen3-VL. \
We strongly advise users not to knowingly generate or allow others to knowingly generate harmful content, \
including hate speech, violence, pornography, deception, etc. \
(注:本演示受 Qwen3-VL 的许可协议限制。我们强烈建议,用户不应传播及不应允许他人传播以下内容,\
包括但不限于仇恨言论、暴力、色情、欺诈相关的有害信息。)""")
demo.queue().launch(
share=args.share,
inbrowser=args.inbrowser,
server_port=args.server_port,
server_name=args.server_name,
)
def main():
args = _get_args()
model, processor, backend = _load_model_processor(args)
_launch_demo(args, model, processor, backend)
if __name__ == '__main__':
main()