Commit d5878167 authored by mashun1

llava-next

# from .demo_modelpart import InferenceDemo
import gradio as gr
import os
# import time
import cv2
# import copy
import torch
# import random
import numpy as np
from llava import conversation as conversation_lib
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
from PIL import Image
import requests
from io import BytesIO
from transformers import TextStreamer
class InferenceDemo(object):
def __init__(self, args, model_path, tokenizer, model, image_processor, context_len) -> None:
disable_torch_init()
self.tokenizer, self.model, self.image_processor, self.context_len = tokenizer, model, image_processor, context_len
# Infer the conversation template from the checkpoint name instead of relying on a module-level global.
model_name = get_model_name_from_path(model_path)
if "llama-2" in model_name.lower():
conv_mode = "llava_llama_2"
elif "v1" in model_name.lower():
conv_mode = "llava_v1"
elif "mpt" in model_name.lower():
conv_mode = "mpt"
elif 'qwen' in model_name.lower():
conv_mode = "qwen_1_5"
else:
conv_mode = "llava_v0"
if args.conv_mode is not None and conv_mode != args.conv_mode:
print("[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(conv_mode, args.conv_mode, args.conv_mode))
else:
args.conv_mode = conv_mode
self.conv_mode = args.conv_mode
self.conversation = conv_templates[args.conv_mode].copy()
self.num_frames = args.num_frames
def is_valid_video_filename(name):
video_extensions = ['avi', 'mp4', 'mov', 'mkv', 'flv', 'wmv', 'mjpeg']
ext = name.split('.')[-1].lower()
return ext in video_extensions
def sample_frames(video_file, num_frames):
video = cv2.VideoCapture(video_file)
total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
# Guard against short clips so the sampling stride is never zero.
interval = max(total_frames // num_frames, 1)
frames = []
for i in range(total_frames):
ret, frame = video.read()
if not ret:
continue
if i % interval == 0:
# Convert OpenCV's BGR frame to RGB before building the PIL image.
frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
video.release()
return frames
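# Illustrative usage only (hypothetical file name, not part of the demo flow):
# frames = sample_frames("examples/clip.mp4", 8)  # -> roughly 8 evenly spaced PIL.Image frames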
def load_image(image_file):
if image_file.startswith("http://") or image_file.startswith("https://"):
response = requests.get(image_file)
if response.status_code == 200:
image = Image.open(BytesIO(response.content)).convert("RGB")
else:
raise ValueError(f"Failed to download image {image_file} (HTTP {response.status_code})")
else:
print("Loading image from local file:", image_file)
image = Image.open(image_file).convert("RGB")
return image
def clear_history(history):
our_chatbot.conversation = conv_templates[our_chatbot.conv_mode].copy()
return None
def clear_response(history):
for index_conv in range(1, len(history)):
# loop until get a text response from our model.
conv = history[-index_conv]
if not (conv[0] is None):
break
question = history[-index_conv][0]
history = history[:-index_conv]
return history, question
def print_like_dislike(x: gr.LikeData):
print(x.index, x.value, x.liked)
def add_message(history, message):
# history=[]
global our_chatbot
if len(history)==0:
our_chatbot = InferenceDemo(args,model_path,tokenizer, model, image_processor, context_len)
for x in message["files"]:
history.append(((x,), None))
if message["text"] is not None:
history.append((message["text"], None))
return history, gr.MultimodalTextbox(value=None, interactive=False)
def bot(history):
text=history[-1][0]
images_this_term=[]
text_this_term=''
# import pdb;pdb.set_trace()
num_new_images = 0
for i,message in enumerate(history[:-1]):
if type(message[0]) is tuple:
images_this_term.append(message[0][0])
if is_valid_video_filename(message[0][0]):
num_new_images+=our_chatbot.num_frames
else:
num_new_images+=1
else:
num_new_images=0
# for message in history[-i-1:]:
# images_this_term.append(message[0][0])
assert len(images_this_term)>0, "must have an image"
# image_files = (args.image_file).split(',')
# image = [load_image(f) for f in images_this_term if f]
image_list=[]
for f in images_this_term:
if is_valid_video_filename(f):
image_list+=sample_frames(f, our_chatbot.num_frames)
else:
image_list.append(load_image(f))
image_tensor = [our_chatbot.image_processor.preprocess(f, return_tensors="pt")["pixel_values"][0].half().to(our_chatbot.model.device) for f in image_list]
image_tensor = torch.stack(image_tensor)
image_token = DEFAULT_IMAGE_TOKEN*num_new_images
# if our_chatbot.model.config.mm_use_im_start_end:
# inp = DEFAULT_IM_START_TOKEN + image_token + DEFAULT_IM_END_TOKEN + "\n" + inp
# else:
inp=text
inp = image_token+ "\n" + inp
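# The prompt places one DEFAULT_IMAGE_TOKEN ("<image>") per image/frame, then a newline, then the user
# text, e.g. two images plus "describe both" becomes "<image><image>\ndescribe both".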
our_chatbot.conversation.append_message(our_chatbot.conversation.roles[0], inp)
# image = None
our_chatbot.conversation.append_message(our_chatbot.conversation.roles[1], None)
prompt = our_chatbot.conversation.get_prompt()
input_ids = tokenizer_image_token(prompt, our_chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(our_chatbot.model.device)
stop_str = our_chatbot.conversation.sep if our_chatbot.conversation.sep_style != SeparatorStyle.TWO else our_chatbot.conversation.sep2
keywords = [stop_str]
stopping_criteria = KeywordsStoppingCriteria(keywords, our_chatbot.tokenizer, input_ids)
streamer = TextStreamer(our_chatbot.tokenizer, skip_prompt=True, skip_special_tokens=True)
# import pdb;pdb.set_trace()
with torch.inference_mode():
output_ids = our_chatbot.model.generate(input_ids, images=image_tensor, do_sample=True, temperature=0.2, max_new_tokens=1024, streamer=streamer, use_cache=False, stopping_criteria=[stopping_criteria])
outputs = our_chatbot.tokenizer.decode(output_ids[0]).strip()
if outputs.endswith(stop_str):
outputs = outputs[:-len(stop_str)]
our_chatbot.conversation.messages[-1][-1] = outputs
history[-1]=[text,outputs]
return history
txt = gr.Textbox(
scale=4,
show_label=False,
placeholder="Enter text and press enter.",
container=False,
)
with gr.Blocks() as demo:
# Informations
title_markdown = ("""
# LLaVA-NeXT Interleave
[[Blog]](https://llava-vl.github.io/blog/2024-06-16-llava-next-interleave/) [[Code]](https://github.com/LLaVA-VL/LLaVA-NeXT) [[Model]](https://huggingface.co/lmms-lab/llava-next-interleave-7b)
""")
tos_markdown = ("""
### Terms of Use (TODO)
By using this service, users are required to agree to the following terms:
The service is a research preview intended for non-commercial use only. It provides only limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
Please click the "Flag" button if you get any inappropriate answer! We will collect those reports to keep improving our moderation.
For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
""")
learn_more_markdown = ("""
### License (TODO)
The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
""")
models = [
"LLaVA-Interleave-7B",
]
cur_dir = os.path.dirname(os.path.abspath(__file__))
gr.Markdown(title_markdown)
chatbot = gr.Chatbot(
[],
elem_id="chatbot",
bubble_full_width=False
)
chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image","video"], placeholder="Enter message or upload file...", show_label=False)
with gr.Row():
upvote_btn = gr.Button(value="Upvote", interactive=True)
downvote_btn = gr.Button(value="Downvote", interactive=True)
flag_btn = gr.Button(value="⚠️ Flag", interactive=True)
#stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=True)
regenerate_btn = gr.Button(value="Regenerate", interactive=True)
clear_btn = gr.Button(value="Clear history", interactive=True)
chat_msg = chat_input.submit(add_message, [chatbot, chat_input], [chatbot, chat_input])
bot_msg = chat_msg.then(bot, chatbot, chatbot, api_name="bot_response")
bot_msg.then(lambda: gr.MultimodalTextbox(interactive=True), None, [chat_input])
chatbot.like(print_like_dislike, None, None)
clear_btn.click(fn=clear_history, inputs=[chatbot], outputs=[chatbot], api_name="clear_all")
with gr.Column():
gr.Examples(examples=[
[{"files": [f"{cur_dir}/examples/code1.jpeg",f"{cur_dir}/examples/code2.jpeg"], "text": "Please pay attention to the movement of the object from the first image to the second image, then write a HTML code to show this movement."}],
[{"files": [f"{cur_dir}/examples/shub.jpg",f"{cur_dir}/examples/shuc.jpg",f"{cur_dir}/examples/shud.jpg"], "text": "what is fun about the images?"}],
[{"files": [f"{cur_dir}/examples/iphone-15-price-1024x576.jpg",f"{cur_dir}/examples/dynamic-island-1024x576.jpg",f"{cur_dir}/examples/iphone-15-colors-1024x576.jpg",f"{cur_dir}/examples/Iphone-15-Usb-c-charger-1024x576.jpg",f"{cur_dir}/examples/A-17-processors-1024x576.jpg"], "text": "The images are the PPT of iPhone 15 review. can you summarize the main information?"}],
[{"files": [f"{cur_dir}/examples/fangao3.jpeg",f"{cur_dir}/examples/fangao2.jpeg",f"{cur_dir}/examples/fangao1.jpeg"], "text": "Do you kown who draw these paintings?"}],
[{"files": [f"{cur_dir}/examples/oprah-winfrey-resume.png",f"{cur_dir}/examples/steve-jobs-resume.jpg"], "text": "Hi, there are two candidates, can you provide a brief description for each of them for me?"}],
[{"files": [f"{cur_dir}/examples/original_bench.jpeg",f"{cur_dir}/examples/changed_bench.jpeg"], "text": "How to edit image1 to make it look like image2?"}],
[{"files": [f"{cur_dir}/examples/twitter2.jpeg",f"{cur_dir}/examples/twitter3.jpeg",f"{cur_dir}/examples/twitter4.jpeg"], "text": "Please write a twitter blog post with the images."}],
[{"files": [f"{cur_dir}/examples/twitter3.jpeg",f"{cur_dir}/examples/twitter4.jpeg"], "text": "Please write a twitter blog post with the images."}],
# [{"files": [f"playground/demo/examples/lion1_.mp4",f"playground/demo/examples/lion2_.mp4"], "text": "The input contains two videos, the first half is the first video and the second half is the second video. What is the difference between the two videos?"}],
], inputs=[chat_input], label="Compare images: ")
demo.queue()
if __name__ == "__main__":
import argparse
argparser = argparse.ArgumentParser()
argparser.add_argument("--server_name", default="0.0.0.0", type=str)
argparser.add_argument("--port", default="6123", type=str)
argparser.add_argument("--model_path", default="", type=str)
# argparser.add_argument("--model-path", type=str, default="facebook/opt-350m")
argparser.add_argument("--model-base", type=str, default=None)
argparser.add_argument("--num-gpus", type=int, default=1)
argparser.add_argument("--conv-mode", type=str, default=None)
argparser.add_argument("--temperature", type=float, default=0.2)
argparser.add_argument("--max-new-tokens", type=int, default=512)
argparser.add_argument("--num_frames", type=int, default=16)
argparser.add_argument("--load-8bit", action="store_true")
argparser.add_argument("--load-4bit", action="store_true")
argparser.add_argument("--debug", action="store_true")
args = argparser.parse_args()
model_path = args.model_path
filt_invalid="cut"
model_name = get_model_name_from_path(args.model_path)
tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit)
our_chatbot = None
# import pdb;pdb.set_trace()
try:
demo.launch(server_name=args.server_name, server_port=args.port, share=True)
except Exception as e:
print(f"Port {args.port} is occupied ({e}), trying port {args.port + 1}")
args.port = args.port + 1
demo.launch(server_name=args.server_name, server_port=args.port, share=True)
import argparse
import torch
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import process_anyres_image,tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
import json
import os
import math
import time
from tqdm import tqdm
from decord import VideoReader, cpu
from transformers import AutoConfig
import cv2
import base64
import openai
from PIL import Image
import numpy as np
def split_list(lst, n):
"""Split a list into n (roughly) equal-sized chunks."""
chunk_size = math.ceil(len(lst) / n)  # ceiling division so every item is covered
return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)]
def get_chunk(lst, n, k):
chunks = split_list(lst, n)
return chunks[k]
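# Illustrative behavior of the two helpers above:
# split_list([1, 2, 3, 4, 5], 2) -> [[1, 2, 3], [4, 5]]
# get_chunk([1, 2, 3, 4, 5], 2, 1) -> [4, 5]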
def parse_args():
"""
Parse command-line arguments.
"""
parser = argparse.ArgumentParser()
# Define the command-line arguments
parser.add_argument("--video_path", help="Path to the video files.", required=True)
parser.add_argument("--output_dir", help="Directory to save the model results JSON.", required=True)
parser.add_argument("--output_name", help="Name of the file for storing results JSON.", required=True)
parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
parser.add_argument("--model-base", type=str, default=None)
parser.add_argument("--conv-mode", type=str, default=None)
parser.add_argument("--chunk-idx", type=int, default=0)
parser.add_argument("--mm_resampler_type", type=str, default="spatial_pool")
parser.add_argument("--mm_spatial_pool_stride", type=int, default=4)
parser.add_argument("--mm_spatial_pool_out_channels", type=int, default=1024)
parser.add_argument("--mm_spatial_pool_mode", type=str, default="average")
parser.add_argument("--image_aspect_ratio", type=str, default="anyres")
parser.add_argument("--image_grid_pinpoints", type=str, default="[(224, 448), (224, 672), (224, 896), (448, 448), (448, 224), (672, 224), (896, 224)]")
parser.add_argument("--mm_patch_merge_type", type=str, default="spatial_unpad")
parser.add_argument("--overwrite", type=lambda x: (str(x).lower() == 'true'), default=True)
parser.add_argument("--for_get_frames_num", type=int, default=4)
parser.add_argument("--load_8bit", type=lambda x: (str(x).lower() == 'true'), default=False)
parser.add_argument("--prompt", type=str, default=None)
parser.add_argument("--api_key", type=str, help="OpenAI API key")
parser.add_argument("--mm_newline_position", type=str, default="no_token")
parser.add_argument("--force_sample", type=lambda x: (str(x).lower() == 'true'), default=False)
parser.add_argument("--add_time_instruction", type=str, default=False)
return parser.parse_args()
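# Illustrative invocation (script name and paths are placeholders):
# python video_inference.py --video_path /path/to/videos --output_dir ./results --output_name preds \
#     --model-path lmms-lab/llava-next-interleave-7b --conv-mode qwen_1_5 --for_get_frames_num 16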
def load_video(video_path,args):
if args.for_get_frames_num == 0:
return np.zeros((1, 336, 336, 3))
vr = VideoReader(video_path, ctx=cpu(0),num_threads=1)
total_frame_num = len(vr)
video_time = total_frame_num / vr.get_avg_fps()
fps = round(vr.get_avg_fps())
frame_idx = [i for i in range(0, len(vr), fps)]
frame_time = [i/fps for i in frame_idx]
if len(frame_idx) > args.for_get_frames_num or args.force_sample:
sample_fps = args.for_get_frames_num
uniform_sampled_frames = np.linspace(0, total_frame_num - 1, sample_fps, dtype=int)
frame_idx = uniform_sampled_frames.tolist()
frame_time = [i/vr.get_avg_fps() for i in frame_idx]
frame_time = ",".join([f"{i:.2f}s" for i in frame_time])
spare_frames = vr.get_batch(frame_idx).asnumpy()
# import pdb;pdb.set_trace()
return spare_frames,frame_time,video_time
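# Illustrative unpacking (hypothetical path):
# frames, frame_time, video_time = load_video("clip.mp4", args)
# `frames` is a (num_frames, H, W, 3) uint8 array from decord, `frame_time` is typically a
# comma-separated string such as "0.00s,2.13s,...", and `video_time` is the clip duration in seconds.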
def load_video_base64(path):
video = cv2.VideoCapture(path)
base64Frames = []
while video.isOpened():
success, frame = video.read()
if not success:
break
_, buffer = cv2.imencode(".jpg", frame)
base64Frames.append(base64.b64encode(buffer).decode("utf-8"))
video.release()
# print(len(base64Frames), "frames read.")
return base64Frames
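# Illustrative usage (hypothetical path): base64_frames = load_video_base64("clip.mp4")
# returns one base64-encoded JPEG string per decoded frame, suitable for the GPT-4V payload below.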
def run_inference(args):
"""
Run inference on ActivityNet QA DataSet using the Video-ChatGPT model.
Args:
args: Command-line arguments.
"""
# Initialize the model
if "gpt4v" != args.model_path:
model_name = get_model_name_from_path(args.model_path)
# Set model configuration parameters if they exist
if args.overwrite == True:
overwrite_config = {}
overwrite_config["mm_spatial_pool_mode"] = args.mm_spatial_pool_mode
overwrite_config["mm_spatial_pool_stride"] = args.mm_spatial_pool_stride
overwrite_config["mm_newline_position"] = args.mm_newline_position
cfg_pretrained = AutoConfig.from_pretrained(args.model_path)
# import pdb;pdb.set_trace()
if "qwen" not in args.model_path.lower():
if "224" in cfg_pretrained.mm_vision_tower:
# suppose the length of text tokens is around 1000, from bo's report
least_token_number = args.for_get_frames_num*(16//args.mm_spatial_pool_stride)**2 + 1000
else:
least_token_number = args.for_get_frames_num*(24//args.mm_spatial_pool_stride)**2 + 1000
scaling_factor = math.ceil(least_token_number/4096)
if scaling_factor >= 2:
if "vicuna" in cfg_pretrained._name_or_path.lower():
print(float(scaling_factor))
overwrite_config["rope_scaling"] = {"factor": float(scaling_factor), "type": "linear"}
overwrite_config["max_sequence_length"] = 4096 * scaling_factor
overwrite_config["tokenizer_model_max_length"] = 4096 * scaling_factor
tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, load_8bit=args.load_8bit, overwrite_config=overwrite_config)
else:
tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name)
else:
pass
# import pdb;pdb.set_trace()
if getattr(model.config, "force_sample", None) is not None:
args.force_sample = model.config.force_sample
else:
args.force_sample = False
# import pdb;pdb.set_trace()
if getattr(model.config, "add_time_instruction", None) is not None:
args.add_time_instruction = model.config.add_time_instruction
else:
args.add_time_instruction = False
# Create the output directory if it doesn't exist
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
output_name = args.output_name
answers_file = os.path.join(args.output_dir, f"{output_name}.json")
ans_file = open(answers_file, "w")
video_path = args.video_path
all_video_pathes = []
# Check if the video_path is a directory or a file
if os.path.isdir(video_path):
# If it's a directory, loop over all files in the directory
for filename in os.listdir(video_path):
# Load the video file
cur_video_path = os.path.join(video_path, filename)
all_video_pathes.append(cur_video_path)
else:
# If it's a file, just process the video
all_video_pathes.append(video_path)
# import pdb;pdb.set_trace()
for video_path in all_video_pathes:
sample_set = {}
question = args.prompt
sample_set["Q"] = question
sample_set["video_name"] = video_path
# Check if the video exists
if os.path.exists(video_path):
if "gpt4v" != args.model_path:
video,frame_time,video_time = load_video(video_path, args)
video = image_processor.preprocess(video, return_tensors="pt")["pixel_values"].half().cuda()
video = [video]
else:
spare_frames = load_video_base64(video_path)
interval = max(int(len(spare_frames) / args.for_get_frames_num), 1)
# try:
# Run inference on the video and add the output to the list
if "gpt4v" != args.model_path:
qs = question
if args.add_time_instruction:
time_instruction = f"The video lasts for {video_time:.2f} seconds, and {len(video[0])} frames are uniformly sampled from it. These frames are located at {frame_time}. Please answer the following questions related to this video."
qs = f'{time_instruction}\n{qs}'
if model.config.mm_use_im_start_end:
qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + "\n" + qs
else:
qs = DEFAULT_IMAGE_TOKEN + "\n" + qs
conv = conv_templates[args.conv_mode].copy()
conv.append_message(conv.roles[0], qs)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).cuda()
if tokenizer.pad_token_id is None:
if "qwen" in tokenizer.name_or_path.lower():
print("Setting pad token to bos token for qwen model.")
tokenizer.pad_token_id = 151643
attention_masks = input_ids.ne(tokenizer.pad_token_id).long().cuda()
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
keywords = [stop_str]
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
cur_prompt = question
else:
prompt = question
system_error = ""
if "gpt4v" != args.model_path:
with torch.inference_mode():
# model.update_prompt([[cur_prompt]])
# import pdb;pdb.set_trace()
# output_ids = model.generate(inputs=input_ids, images=video, attention_mask=attention_masks, modalities="video", do_sample=True, temperature=0.2, max_new_tokens=1024, use_cache=True, stopping_criteria=[stopping_criteria])
if "mistral" not in cfg_pretrained._name_or_path.lower():
output_ids = model.generate(inputs=input_ids, images=video, attention_mask=attention_masks, modalities="video", do_sample=False, temperature=0.0, max_new_tokens=1024, top_p=0.1,num_beams=1,use_cache=True, stopping_criteria=[stopping_criteria])
# output_ids = model.generate(inputs=input_ids, images=video, attention_mask=attention_masks, modalities="video", do_sample=True, temperature=0.2, max_new_tokens=1024, use_cache=True, stopping_criteria=[stopping_criteria])
else:
output_ids = model.generate(inputs=input_ids, images=video, attention_mask=attention_masks, modalities="video", do_sample=False, temperature=0.0, max_new_tokens=1024, top_p=0.1, num_beams=1, use_cache=True)
# output_ids = model.generate(inputs=input_ids, images=video, attention_mask=attention_masks, modalities="video", do_sample=True, temperature=0.2, max_new_tokens=1024, use_cache=True)
else:
openai.api_key = args.api_key # Your API key here
max_num_retries = 0
retry = 5
PROMPT_MESSAGES = [
{
"role": "user",
"content": [
f"These are frames from a video that I want to upload. Answer me one question of this video: {prompt}",
*map(lambda x: {"image": x, "resize": 336}, video[0::interval]),
],
},
]
params = {
"model": "gpt-4-vision-preview", #gpt-4-1106-vision-preview
"messages": PROMPT_MESSAGES,
"max_tokens": 1024,
}
success_flag = False
while max_num_retries < retry:
max_num_retries += 1
try:
result = openai.ChatCompletion.create(**params)
outputs = result.choices[0].message.content
success_flag = True
break
except Exception as inst :
if 'error' in dir(inst):
# import pdb;pdb.set_trace()
if inst.error.code == 'rate_limit_exceeded':
if "TPM" in inst.error.message:
time.sleep(30)
continue
else:
import pdb;pdb.set_trace()
elif inst.error.code == 'insufficient_quota':
print(f'insufficient_quota key')
exit()
elif inst.error.code == 'content_policy_violation':
print(f'content_policy_violation')
system_error = "content_policy_violation"
break
print('Found error message in response: ', str(inst.error.message), 'error code: ', str(inst.error.code))
continue
if not success_flag:
print(f'Calling OpenAI failed after retrying {max_num_retries} times. Check the logs for details.')
exit()
if "gpt4v" != args.model_path:
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
else:
print(len(spare_frames[0::interval]))
print(f"Question: {prompt}\n")
print(f"Response: {outputs}\n")
if "gpt4v" == args.model_path:
if system_error == 'content_policy_violation':
continue
elif system_error == "":
continue
else:
import pdb;pdb.set_trace()
# import pdb;pdb.set_trace()
if "mistral" not in cfg_pretrained._name_or_path.lower():
if outputs.endswith(stop_str):
outputs = outputs[: -len(stop_str)]
outputs = outputs.strip()
sample_set["pred"] = outputs
ans_file.write(json.dumps(sample_set, ensure_ascii=False) + "\n")
ans_file.flush()
ans_file.close()
if __name__ == "__main__":
args = parse_args()
run_inference(args)
import json
from math import ceil
def split_json_file(input_file, n_splits):
# Read the JSON file
with open(input_file, "r") as file:
data = json.load(file)
# Calculate the size of each split
total_items = len(data)
items_per_split = ceil(total_items / n_splits)
# Split the data and save into separate files
for i in range(n_splits):
start_index = i * items_per_split
end_index = min((i + 1) * items_per_split, total_items)
split_data = data[start_index:end_index]
# Write the split data to a new JSON file
with open(f"{input_file.split('.')[0]}_split_{i}.json", "w") as split_file:
json.dump(split_data, split_file, indent=4)
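# Illustrative behavior (hypothetical file name): split_json_file("data.json", 3) on a 10-item list
# writes data_split_0.json (items 0-3), data_split_1.json (items 4-7), and data_split_2.json (items 8-9).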
def main():
import argparse
parser = argparse.ArgumentParser(description="Split a JSON file into multiple parts.")
parser.add_argument("--input_file", type=str, help="The JSON file to split")
parser.add_argument("--n_splits", type=int, help="The number of splits")
args = parser.parse_args()
split_json_file(args.input_file, args.n_splits)
if __name__ == "__main__":
main()
import os
import shutil
import glob
def remove_checkpoints(directory, pattern):
# Walk through the directory
for root, dirs, files in os.walk(directory):
# Use glob to find paths matching the pattern
for file_path in glob.glob(os.path.join(root, pattern)):
# Check if it is a directory
if "llava-1.6-mistral-7b" in file_path:
continue
if os.path.isdir(file_path):
# Remove the directory
print(f"Removing {file_path}")
input("Press Enter to continue...")
shutil.rmtree(file_path)
print(f"Removed directory: {file_path}")
else:
print(f"Removing {file_path}")
input("Press Enter to continue...")
# Remove the file
os.remove(file_path)
print(f"Removed file: {file_path}")
# Directory containing the checkpoints
directory = "/mnt/bn/vl-research/checkpoints/feng/"
# Pattern to match in the file names
pattern = "global_step*"
# Call the function
remove_checkpoints(directory, pattern)
import argparse
import json
import time
import os
import tqdm
import sglang as sgl
from sglang.test.test_utils import select_sglang_backend
from sglang.utils import dump_state_text
@sgl.function
def image_description(s, image_file):
prompt = "Please generate detailed descriptions of the given image."
s += sgl.user(sgl.image(image_file) + prompt)
s += sgl.assistant(sgl.gen("answer", max_tokens=1024, temperature=0.0))
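# Illustrative single call, assuming an sglang backend has already been selected (the image path is a placeholder):
# state = image_description.run(image_file="/path/to/image.jpg")
# print(state["answer"])
# The batched form used in main(), image_description.run_batch([...]), follows the same pattern.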
def load_progress(progress_file):
print(f"Load progress from {progress_file}")
if os.path.exists(progress_file):
with open(progress_file, "r") as f:
return json.load(f)
return {"last_index": -1, "last_chunk": -1, "results": [], "annotations": []}
def save_progress(progress_file, progress_data):
with open(progress_file, "w") as f:
json.dump(progress_data, f, indent=2)
def find_images_in_subfolders(folder_path):
image_extensions = (".png", ".jpg", ".jpeg", ".gif", ".bmp")
image_files = []
for root, dirs, files in os.walk(folder_path):
for file in files:
if file.endswith(image_extensions):
image_files.append(os.path.join(root, file))
return image_files
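# Note: not called in main() below; illustrative usage (hypothetical path):
# image_files = find_images_in_subfolders("/data/images")  # recursive scan filtered by extension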
def main(args):
dist_rank = args.dist
dist_size = args.total_dist
base_dir = os.path.dirname(args.result_file)
os.makedirs(base_dir, exist_ok=True) # Ensure the base directory exists
progress_file = f"{base_dir}/progress_{dist_rank}_or_{dist_size}.json"
progress_data = load_progress(progress_file)
with open(args.json_path, "r") as fp:
data = json.load(fp)
image_files = [os.path.join(args.images_root, item["image"]) for item in data]
image_files = image_files[: args.limit] if args.limit > 0 else image_files
# Shard the data
shard_size = len(image_files) // dist_size
start_index = shard_size * dist_rank
end_index = start_index + shard_size if dist_rank < dist_size - 1 else len(image_files)
shard_files = image_files[start_index:end_index]
print(f"Querying {len(shard_files)} images from index {start_index} to {end_index - 1}")
# Select backend
backend = select_sglang_backend(args)
sgl.set_default_backend(backend)
tic = time.time()
batch_size = args.parallel
for batch_start in tqdm.tqdm(range(0, len(shard_files), batch_size)):
batch_end = min(batch_start + batch_size, len(shard_files))
if batch_start <= progress_data.get("last_index", -1):
print(f"Skipping already processed batch starting at {batch_start}")
continue
batch_arguments = [{"image_file": image_file} for image_file in shard_files[batch_start:batch_end]]
try:
batch_states = image_description.run_batch(batch_arguments, temperature=0, num_threads=args.parallel, progress_bar=False)
for i, ret in enumerate(batch_states):
image_file = batch_arguments[i]["image_file"]
caption = ret.text().split("ASSISTANT:")[-1].strip()
progress_data["annotations"].append({"image_file": image_file, "caption": caption})
progress_data["last_index"] = batch_start + i # Update last_index relative to this rank's shard
save_progress(progress_file, progress_data)
except Exception as e:
print(f"Error during batch processing: {e}")
save_progress(progress_file, progress_data)
break
latency = time.time() - tic
print(f"Latency: {latency:.3f}")
value = {
"task": "image_captioning",
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
"num_requests": len(shard_files),
"parallel": args.parallel,
"results": progress_data["annotations"],
}
result_file = args.result_file.replace(".json", f"_shard_{dist_rank}_or_{dist_size}.json")
print(f"Write output to {result_file}")
with open(result_file, "w") as fout:
json.dump(value, fout, indent=2)
save_progress(progress_file, progress_data)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--images_root", type=str, default="/mnt/bn/vl-research/data/llava_data/cc3m")
parser.add_argument("--json_path", type=str, default="/mnt/bn/vl-research/data/llava_instruct/cc3m_recap_requery_363707.json")
parser.add_argument("--max_tokens", type=int, default=1024)
parser.add_argument("--parallel", type=int, default=32)
parser.add_argument("--backend", type=str, default="srt")
parser.add_argument("--host", type=str, default="http://127.0.0.1")
parser.add_argument("--port", type=int, default=30000)
parser.add_argument("--result_file", type=str, default="/mnt/bn/vl-research/workspace/boli01/projects/LLaVA_Next/playground/sgl_llava_inference.json")
parser.add_argument("--limit", type=int, default=-1)
parser.add_argument("--dist", type=int, default=0, help="The rank of the distributed machine")
parser.add_argument("--total_dist", type=int, default=6, help="Total number of distributed machines")
args = parser.parse_args()
main(args)
from datasets import Dataset, Features, Value, ClassLabel, Sequence, Image
import json
import PIL.Image as pil_image
from io import BytesIO
from tqdm import tqdm
json_paths = [
# "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/mavis_math_metagen_87358.json",
# "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/mavis_math_rule_geo_100000.json",
# "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/k12_printing_train_256646.json",
# "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/iiit5k_annotations_2000.json",
# "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/hme100k_train_clean_74502.json",
# "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/ai2d_azuregpt_detailed_understanding_4874.json",
# "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/infographic_vqa_4404.json",
# "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/infographic_azuregpt4v_1992.json",
# "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/lrv_chart_1787.json",
# "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/lrv_normal_gpt4v_filtered_10500.json",
# "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/scienceqa_nona_context_19218.json",
# "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/allava_instruct_vflan4v_20000.json",
# "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/allava_instruct_laion4v_50000.json",
# "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/textocr_gpt4v_train_converted_25114.json",
# "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/ai2d_train_internvl_single_12413.json",
# "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/textcaps_train_21952.json",
# "/mnt/bn/vl-research/data/llava_instruct/ureader_new/ureader_qa_sft.json",
# "/mnt/bn/vl-research/data/llava_instruct/ureader_new/ureader_cap_sft.json",
# "/mnt/bn/vl-research/data/llava_instruct/ureader_new/ureader_ie_sft.json",
# "/mnt/bn/vl-research/data/llava_instruct/ureader_new/ureader_kg_sft.json",
# "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/vision_flan_filtered_186070.json",
# "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/mathqa_29837.json",
# "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/geo3k_2101.json",
# "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/geo170k_qa_converted_67833.json",
# "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/geo170k_align_converted_60252.json",
# "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/sharegpt4v-coco-50k.json",
# "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/sharegpt4v-knowledge-2k.json",
# "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/sharegpt4v-llava-30k.json",
# "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/sharegpt4v-sam-20k.json",
# "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_CLEVR-Math_5290.json",
# "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_FigureQA_17597.json",
# "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_Geometry3K_9734.json",
# "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_GeoQA+_17172.json",
# "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_GEOS_508.json",
# "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_IconQA_22599.json",
# "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_MapQA_5235.json",
# "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_PMC-VQA_35958.json",
# "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_Super-CLEVR_8652.json",
# "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_TabMWP_22462.json",
# "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_UniGeo_11959.json",
# "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_VizWiz_6614.json",
# "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/magpie_pro_qwen2_72b_st_300000_sp_token_fltd_299992.json",
# "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/magpie_pro_l3_80b_st_300000.json",
# "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/magpie_pro_l3_80b_mt_300000_sp_token_fltd_299998.json",
# "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/image_textualization_dataset_filtered.json",
# "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/cambrian_filtered_gpt4vo_sp_token_fltd_max10k.json",
"/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/sharegpt4o_dataset.jsonl",
"/mnt/bn/vl-research/data/llava_instruct/cauldron/ai2d_llava_format_2434.json",
"/mnt/bn/vl-research/data/llava_instruct/cauldron/aokvqa_16539_llava_format.json",
"/mnt/bn/vl-research/data/llava_instruct/cauldron/chart2text_26961.json",
"/mnt/bn/vl-research/data/llava_instruct/cauldron/chartqa_18265_llava_format.json",
"/mnt/bn/vl-research/data/llava_instruct/cauldron/clevr_70000_llava_format.json",
"/mnt/bn/vl-research/data/llava_instruct/cauldron/diagram_image_to_text_300.json",
"/mnt/bn/vl-research/data/llava_instruct/cauldron/dvqa_200000_llava_format.json",
"/mnt/bn/vl-research/data/llava_instruct/cauldron/figureqa_100000_llava_format.json",
"/mnt/bn/vl-research/data/llava_instruct/cauldron/geomverse_9303.json",
"/mnt/bn/vl-research/data/llava_instruct/cauldron/hateful_memes_8500_llava_format.json",
"/mnt/bn/vl-research/data/llava_instruct/cauldron/hitab_2500_llava_format.json",
"/mnt/bn/vl-research/data/llava_instruct/cauldron/iam_5663.json",
"/mnt/bn/vl-research/data/llava_instruct/cauldron/raven_42000.json",
"/mnt/bn/vl-research/data/llava_instruct/cauldron/iconqa_llava_format_27307.json",
"/mnt/bn/vl-research/data/llava_instruct/cauldron/infographic_vqa_2118_llava_format.json",
"/mnt/bn/vl-research/data/llava_instruct/cauldron/intergps_1280_llava_format.json",
"/mnt/bn/vl-research/data/llava_instruct/cauldron/mapqa_37417_llava_format.json",
"/mnt/bn/vl-research/data/llava_instruct/cauldron/multihiertt_7619.json",
"/mnt/bn/vl-research/data/llava_instruct/cauldron/rendered_text_10000.json",
"/mnt/bn/vl-research/data/llava_instruct/cauldron/robut_sqa_8514.json",
"/mnt/bn/vl-research/data/llava_instruct/cauldron/robut_wikisql_74989.json",
"/mnt/bn/vl-research/data/llava_instruct/cauldron/robut_wtq_38246_llava_format.json",
"/mnt/bn/vl-research/data/llava_instruct/cauldron/screen2words_15730.json",
"/mnt/bn/vl-research/data/llava_instruct/cauldron/scienceqa_llava_format_4976.json",
"/mnt/bn/vl-research/data/llava_instruct/cauldron/tabmwp_22722.json",
"/mnt/bn/vl-research/data/llava_instruct/cauldron/tallyqa_98680_llava_format.json",
"/mnt/bn/vl-research/data/llava_instruct/cauldron/st_vqa_17247_llava_format.json",
"/mnt/bn/vl-research/data/llava_instruct/cauldron/tqa_llava_format_27307.json",
"/mnt/bn/vl-research/data/llava_instruct/cauldron/visual7w_llava_format_14366.json",
"/mnt/bn/vl-research/data/llava_instruct/cauldron/visualmrc_3027.json",
"/mnt/bn/vl-research/data/llava_instruct/cauldron/vqarad_313_llava_format.json",
"/mnt/bn/vl-research/data/llava_instruct/cauldron/vsr_2157_llava_format.json",
"/mnt/bn/vl-research/data/llava_instruct/cauldron/vistext_9969.json",
"/mnt/bn/vl-research/data/llava_instruct/cauldron/websight_10000.json"
]
short_names = [
# "mavis_math_metagen",
# "mavis_math_rule_geo",
# "k12_printing",
# "iiit5k",
# "hme100k",
# "ai2d(gpt4v)",
# "infographic_vqa",
# "infographic(gpt4v)",
# "lrv_chart",
# "lrv_normal(filtered)",
# "scienceqa(nona_context)",
# "allava_instruct_vflan4v",
# "allava_instruct_laion4v",
# "textocr(gpt4v)",
# "ai2d(internvl)",
# "textcaps",
# "ureader_qa", # need to re-upload
# "ureader_cap", # need to re-upload
# "ureader_ie", # need to re-upload
# "ureader_kg", # need to re-upload
# "vision_flan(filtered)",
# "mathqa",
# "geo3k",
# "geo170k(qa)",
# "geo170k(align)",
# "sharegpt4v(coco)",
# "sharegpt4v(knowledge)",
# "sharegpt4v(llava)",
# "sharegpt4v(sam)",
# "CLEVR-Math(MathV360K)",
# "FigureQA(MathV360K)",
# "Geometry3K(MathV360K)",
# "GeoQA+(MathV360K)",
# "GEOS(MathV360K)",
# "IconQA(MathV360K)",
# "MapQA(MathV360K)",
# "PMC-VQA(MathV360K)",
# "Super-CLEVR(MathV360K)",
# "TabMWP(MathV360K)",
# "UniGeo(MathV360K)",
# "VizWiz(MathV360K)",
# "magpie_pro(qwen2_72b_st)",
# "magpie_pro(l3_80b_st)",
# "magpie_pro(l3_80b_mt)",
# "image_textualization(filtered)",
# "cambrian(filtered_gpt4vo)", # need to re-upload
"sharegpt4o",
"ai2d(cauldron,llava_format)",
"aokvqa(cauldron,llava_format)",
"chart2text(cauldron)",
"chartqa(cauldron,llava_format)",
"clevr(cauldron,llava_format)",
"diagram_image_to_text(cauldron)",
"dvqa(cauldron,llava_format)",
"figureqa(cauldron,llava_format)",
"geomverse(cauldron)",
"hateful_memes(cauldron,llava_format)",
"hitab(cauldron,llava_format)",
"iam(cauldron)",
"raven(cauldron)",
"iconqa(cauldron,llava_format)",
"infographic_vqa_llava_format",
"intergps(cauldron,llava_format)",
"mapqa(cauldron,llava_format)",
"multihiertt(cauldron)",
"rendered_text(cauldron)",
"robut_sqa(cauldron)",
"robut_wikisql(cauldron)",
"robut_wtq(cauldron,llava_format)",
"screen2words(cauldron)",
"scienceqa(cauldron,llava_format)",
"tabmwp(cauldron)",
"tallyqa(cauldron,llava_format)",
"st_vqa(cauldron,llava_format)",
"tqa(cauldron,llava_format)",
"visual7w(cauldron,llava_format)",
"visualmrc(cauldron)",
"vqarad(cauldron,llava_format)",
"vsr(cauldron,llava_format)",
"vistext(cauldron)",
"websight(cauldron)"
]
def upload_data(json_path, short_name):
def gen():
if json_path.endswith(".jsonl"):
with open(json_path, "r") as f:
data = [json.loads(line) for line in f]
else:
with open(json_path, "r") as f:
data = json.load(f)
preview_index = 5
idx = 0
for item in tqdm(data):
# NOTE: the first `preview_index` items are only printed for inspection and are not yielded/uploaded.
if preview_index > 0:
preview_index -= 1
print(item)
continue
try:
if "image" in item:
image_path = f"/mnt/bn/vl-research/data/llava_data/{item['image']}"
try:
with open(image_path, "rb") as img_file:
image = pil_image.open(BytesIO(img_file.read()))
except Exception:
print(f"Failed to load image {item['image']}")
continue
else:
image = None
item_id = item["id"] if "id" in item else f"{idx:06d}"
yield {"id": item_id, "image": image, "conversations": item["conversations"], "data_source": short_name}
idx += 1
except Exception as e:
print(e)
continue
hf_dataset = Dataset.from_generator(generator=gen, num_proc=32)
hf_dataset.push_to_hub("lmms-lab/LLaVA-OneVision-Data", config_name=short_name, split="train")
for json_path, short_name in zip(json_paths, short_names):
upload_data(json_path, short_name)