import torch
import os
import argparse
import numpy as np
import copy
import gradio as gr
import re
import torchaudio
import io
import ffmpeg
from vita.constants import DEFAULT_AUDIO_TOKEN, DEFAULT_IMAGE_TOKEN, MAX_IMAGE_LENGTH, MIN_IMAGE_LENGTH
from vita.conversation import conv_templates, SeparatorStyle
from vita.util.mm_utils import tokenizer_image_token, tokenizer_image_audio_token
from PIL import Image
from decord import VideoReader, cpu
from vllm import LLM, SamplingParams
from transformers import AutoConfig, AutoModel, AutoTokenizer, AutoFeatureExtractor
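# Fullwidth/CJK punctuation used to trim a single trailing punctuation mark from text queries.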
PUNCTUATION = "!?。＂＃＄％＆＇（）＊＋，－／：；＜＝＞＠［＼］＾＿｀｛｜｝～｟｠｢｣､、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏."
def remove_special_characters(input_str):
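    """Remove the <1>/<2>/<3> markers that the model emits from generated text before display."""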
return input_str.replace('<2>', '').replace('<1>', '').replace('<3>', '')
def is_video(file_path):
video_extensions = {'.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv', '.webm'}
_, ext = os.path.splitext(file_path)
return ext.lower() in video_extensions
def is_image(file_path):
image_extensions = {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff'}
_, ext = os.path.splitext(file_path)
return ext.lower() in image_extensions
def is_wav(file_path):
wav_extensions = {'.wav'}
_, ext = os.path.splitext(file_path)
return ext.lower() in wav_extensions
def convert_webm_to_mp4(input_file, output_file):
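    """Re-encode a browser-recorded .webm clip into an H.264/AAC .mp4 using ffmpeg-python."""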
try:
(
ffmpeg
.input(input_file)
.output(output_file, vcodec='libx264', acodec='aac')
.run()
)
print(f"Conversion successful: {output_file}")
except ffmpeg.Error as e:
print(f"Error: {e.stderr.decode()}")
raise
def _get_rawvideo_dec(video_path, max_frames=MAX_IMAGE_LENGTH, min_frames=MIN_IMAGE_LENGTH, video_framerate=1, s=None, e=None):
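    """Decode a video into RGB PIL frames sampled at roughly `video_framerate` fps between optional
    start/end seconds (s, e); the total frame count is clamped to [min_frames, max_frames]."""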
if s is None or e is None:
start_time, end_time = None, None
else:
start_time = int(s)
end_time = int(e)
start_time = max(start_time, 0)
end_time = max(end_time, 0)
if start_time > end_time:
start_time, end_time = end_time, start_time
elif start_time == end_time:
end_time = start_time + 1
if os.path.exists(video_path):
vreader = VideoReader(video_path, ctx=cpu(0))
else:
        raise FileNotFoundError(video_path)
fps = vreader.get_avg_fps()
f_start = 0 if start_time is None else int(start_time * fps)
f_end = int(min(1000000000 if end_time is None else end_time * fps, len(vreader) - 1))
num_frames = f_end - f_start + 1
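    # Sample roughly `video_framerate` frames per second, then clamp the number of frames to [min_frames, max_frames].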
if num_frames > 0:
sample_fps = int(video_framerate)
t_stride = int(round(float(fps) / sample_fps))
all_pos = list(range(f_start, f_end + 1, t_stride))
if len(all_pos) > max_frames:
sample_pos = [all_pos[_] for _ in np.linspace(0, len(all_pos) - 1, num=max_frames, dtype=int)]
elif len(all_pos) < min_frames:
sample_pos = [all_pos[_] for _ in np.linspace(0, len(all_pos) - 1, num=min_frames, dtype=int)]
else:
sample_pos = all_pos
patch_images = [Image.fromarray(f).convert("RGB") for f in vreader.get_batch(sample_pos).asnumpy()]
return patch_images, len(patch_images)
else:
print(f"video path: {video_path} error.")
def _parse_text(text):
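    """Convert model output into chatbot-safe HTML: fenced code blocks become <pre><code> sections
    and markup characters inside them are escaped."""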
lines = text.split("\n")
lines = [line for line in lines if line != ""]
count = 0
for i, line in enumerate(lines):
if "```" in line:
count += 1
items = line.split("`")
if count % 2 == 1:
                lines[i] = f'<pre><code class="language-{items[-1]}">'
            else:
                lines[i] = "<br></code></pre>"
else:
            if i > 0:
                if count % 2 == 1:
                    line = line.replace("`", r"\`")
                    line = line.replace("<", "&lt;")
                    line = line.replace(">", "&gt;")
                    line = line.replace(" ", "&nbsp;")
                    line = line.replace("*", "&ast;")
                    line = line.replace("_", "&lowbar;")
                    line = line.replace("-", "&#45;")
                    line = line.replace(".", "&#46;")
                    line = line.replace("!", "&#33;")
                    line = line.replace("(", "&#40;")
                    line = line.replace(")", "&#41;")
                    line = line.replace("$", "&#36;")
                lines[i] = "<br>" + line
return "".join(lines)
def _launch_demo(llm, model_config, sampling_params, tokenizer, feature_extractor):
def predict(_chatbot, task_history):
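        """Replay task_history into a VITA conversation, run vLLM generation, and yield the parsed answer to the chatbot."""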
chat_query = task_history[-1][0]
print(task_history)
conv_mode = "mixtral_two"
conv = conv_templates[conv_mode].copy()
all_audio_path = []
all_visual_tensor = []
qs = ''
input_mode = 'lang'
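        # Rebuild the conversation: images/videos contribute <image> placeholder tokens, wav files become <audio> turns, and text closes out a user turn.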
for i, (q, a) in enumerate(task_history):
if isinstance(q, (tuple, list)):
if is_image(q[0]):
images = [Image.open(q[0]).convert("RGB")]
all_visual_tensor.extend(images)
input_mode = 'image'
qs += DEFAULT_IMAGE_TOKEN * len(images) + '\n'
elif is_video(q[0]):
video_frames, slice_len = _get_rawvideo_dec(q[0])
all_visual_tensor.extend(video_frames)
input_mode = 'video'
qs += DEFAULT_IMAGE_TOKEN * slice_len + '\n'
elif is_wav(q[0]):
if a is not None and a.startswith('<2>'):
continue
else:
all_audio_path.append(q[0])
new_q = qs + DEFAULT_AUDIO_TOKEN
qs = ''
conv.append_message(conv.roles[0], new_q)
conv.append_message(conv.roles[1], a)
else:
new_q = qs + q
qs = ''
conv.append_message(conv.roles[0], new_q)
conv.append_message(conv.roles[1], a)
print(conv)
prompt = conv.get_prompt(input_mode)
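        # Tokenize the prompt; prompts that contain audio need the variant that also expands audio placeholder tokens.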
if all_audio_path != []:
input_ids = tokenizer_image_audio_token(
prompt, tokenizer,
image_token_index=model_config.image_token_index,
audio_token_index=model_config.audio_token_index
)
audio_list = []
for single_audio_path in all_audio_path:
try:
audio, original_sr = torchaudio.load(single_audio_path)
# The FeatureExtractor was trained using a sampling rate of 16000 Hz
target_sr = 16000
# Resample
if original_sr != target_sr:
resampler = torchaudio.transforms.Resample(orig_freq=original_sr, new_freq=target_sr)
audio = resampler(audio)
audio_features = feature_extractor(audio, sampling_rate=target_sr, return_tensors="pt")["input_features"]
audio_list.append(audio_features.squeeze(0))
except Exception as e:
print(f"Error processing {single_audio_path}: {e}")
else:
input_ids = tokenizer_image_token(
prompt, tokenizer,
image_token_index=model_config.image_token_index
)
        data_prompt = {"prompt_token_ids": input_ids}
        multi_modal_data = {}
        if all_visual_tensor:
            multi_modal_data["image"] = all_visual_tensor
        if all_audio_path:
            multi_modal_data["audio"] = audio_list
        if multi_modal_data:
            data_prompt["multi_modal_data"] = multi_modal_data
        output = llm.generate(data_prompt, sampling_params=sampling_params)
outputs = output[0].outputs[0].text
task_history[-1] = (chat_query, outputs)
remove_special_characters_output = remove_special_characters(outputs)
_chatbot[-1] = (chat_query, _parse_text(remove_special_characters_output))
print("query",chat_query)
print("task_history",task_history)
print(_chatbot)
print("answer: ",outputs)
yield _chatbot
def add_text(history, task_history, text):
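        """Record a text query; a single trailing punctuation mark is dropped from the model-side copy but kept in the displayed one."""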
task_text = text
if len(text) >= 2 and text[-1] in PUNCTUATION and text[-2] not in PUNCTUATION:
task_text = text[:-1]
history = history + [(_parse_text(text), None)]
task_history = task_history + [(task_text, None)]
return history, task_history, ""
def add_file(history, task_history, file):
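        """Add an uploaded image or video file to both the visible chat history and the task history."""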
history = history + [((file.name,), None)]
task_history = task_history + [((file.name,), None)]
return history, task_history
def add_audio(history, task_history, file):
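        """Add a recorded or uploaded audio file to both histories; ignore empty submissions."""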
print(file)
if file is None:
return history, task_history
history = history + [((file,), None)]
task_history = task_history + [((file,), None)]
return history, task_history
def add_video(history, task_history, file):
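        """Add a webcam recording to the task history, converting .webm to .mp4 first; the visible chat is left unchanged."""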
print(file)
if file is None:
return history, task_history
new_file_name = file.replace(".webm",".mp4")
if file.endswith(".webm"):
convert_webm_to_mp4(file, new_file_name)
task_history = task_history + [((new_file_name,), None)]
return history, task_history
def reset_user_input():
return gr.update(value="")
def reset_state(task_history):
task_history.clear()
return []
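    # Gradio layout: chat pane and text box on top; upload button and webcam recorder in the wide column, clear-history button and microphone in the narrow one.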
with gr.Blocks(title="VideoMLLM") as demo:
gr.Markdown("""VITA""")
chatbot = gr.Chatbot(label='VITA', elem_classes="control-height", height=500)
query = gr.Textbox(lines=2, label='Text Input')
task_history = gr.State([])
with gr.Row():
add_text_button = gr.Button("Submit Text (提交文本)")
add_audio_button = gr.Button("Submit Audio (提交音频)")
with gr.Row():
with gr.Column(scale=2):
addfile_btn = gr.UploadButton("📁 Upload (上传文件[视频,图片])", file_types=["video", "image"])
video_input = gr.Video(sources=[ "webcam"], height=400, width=700, container=True, interactive=True, show_download_button=True, label="📹 Video Recording (视频录制)")
with gr.Column(scale=1):
empty_bin = gr.Button("🧹 Clear History (清除历史)")
record_btn = gr.Audio(sources=[ "microphone","upload"], type="filepath", label="🎤 Record or Upload Audio (录音或上传音频)", show_download_button=True, waveform_options=gr.WaveformOptions(sample_rate=16000))
add_text_button.click(add_text, [chatbot, task_history, query], [chatbot, task_history], show_progress=True).then(
reset_user_input, [], [query]
).then(
predict, [chatbot, task_history], [chatbot], show_progress=True
)
video_input.stop_recording(add_video, [chatbot, task_history, video_input], [chatbot, task_history], show_progress=True)
empty_bin.click(reset_state, [task_history], [chatbot], show_progress=True)
addfile_btn.upload(add_file, [chatbot, task_history, addfile_btn], [chatbot, task_history], show_progress=True)
add_audio_button.click(add_audio, [chatbot, task_history,record_btn], [chatbot, task_history], show_progress=True).then(
predict, [chatbot, task_history], [chatbot], show_progress=True
)
server_port = 18806
demo.launch(
share=False,
debug=True,
server_name="0.0.0.0",
server_port=server_port,
show_api=False,
show_error=False,
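        # NOTE: placeholder credentials; replace them before exposing the demo beyond a trusted network.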
auth=('123','123'),
)
def main(model_path):
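    """Load the model with vLLM (fp16, 2-way tensor parallel) plus its config, tokenizer and audio feature extractor, then launch the Gradio demo."""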
llm = LLM(
model=model_path,
dtype="float16",
tensor_parallel_size=2,
trust_remote_code=True,
gpu_memory_utilization=0.85,
disable_custom_all_reduce=True,
limit_mm_per_prompt={'image':256,'audio':50}
)
model_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
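    # Near-greedy decoding; special tokens are kept in the raw output and stripped later for display.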
sampling_params = SamplingParams(temperature=0.01, max_tokens=512, best_of=1, skip_special_tokens=False)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
feature_extractor = AutoFeatureExtractor.from_pretrained(model_path, subfolder="feature_extractor", trust_remote_code=True)
_launch_demo(llm, model_config, sampling_params, tokenizer, feature_extractor)
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Run the web demo with your model path.')
parser.add_argument('model_path', type=str, help='Path to the model')
args = parser.parse_args()
main(args.model_path)