Commit d5878167 authored by mashun1

llava-next
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IGNORE_INDEX
from llava.conversation import conv_templates, SeparatorStyle
from PIL import Image
import requests
import copy
import torch
import os
from pathlib import Path
current_dir = str(Path(__file__).resolve().parent)
# Load model
# pretrained = "lmms-lab/llava-onevision-qwen2-0.5b-ov"
pretrained = os.path.join(current_dir, "ckpts", "llava-onevision-qwen2-0.5b-ov")
model_name = "llava_qwen"
device = "cuda"
device_map = "auto"
llava_model_args = {
    "multimodal": True,
}
overwrite_config = {}
overwrite_config["image_aspect_ratio"] = "pad"
llava_model_args["overwrite_config"] = overwrite_config
tokenizer, model, image_processor, max_length = load_pretrained_model(pretrained, None, model_name, device_map=device_map, **llava_model_args)
model.eval()
# Load two images
url1 = os.path.join(current_dir, "examples", "llava_v1_5_radar.jpg")
url2 = os.path.join(current_dir, "examples", "llava_logo.png")
# image1 = Image.open(requests.get(url1, stream=True).raw)
# image2 = Image.open(requests.get(url2, stream=True).raw)
image1 = Image.open(url1)
image2 = Image.open(url2)
images = [image1, image2]
image_tensors = process_images(images, image_processor, model.config)
image_tensors = [_image.to(dtype=torch.float16, device=device) for _image in image_tensors]
# Prepare interleaved text-image input
conv_template = "qwen_1_5"
question = f"{DEFAULT_IMAGE_TOKEN} This is the first image. Can you describe what you see?\n\nNow, let's look at another image: {DEFAULT_IMAGE_TOKEN}\nWhat's the difference between these two images?"
conv = copy.deepcopy(conv_templates[conv_template])
conv.append_message(conv.roles[0], question)
conv.append_message(conv.roles[1], None)
prompt_question = conv.get_prompt()
input_ids = tokenizer_image_token(prompt_question, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(device)
image_sizes = [image.size for image in images]
# Generate response
cont = model.generate(
    input_ids,
    images=image_tensors,
    image_sizes=image_sizes,
    do_sample=False,
    temperature=0,
    max_new_tokens=4096,
)
text_outputs = tokenizer.batch_decode(cont, skip_special_tokens=True)
print(text_outputs[0])
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IGNORE_INDEX
from llava.conversation import conv_templates, SeparatorStyle
from PIL import Image
import requests
import copy
import torch
import sys
import warnings
import os
from pathlib import Path
current_dir = str(Path(__file__).resolve().parent)
warnings.filterwarnings("ignore")
# pretrained = "lmms-lab/llava-onevision-qwen2-0.5b-si"
pretrained = os.path.join(current_dir, "ckpts", "llava-onevision-qwen2-0.5b-si")
model_name = "llava_qwen"
device = "cuda"
device_map = "auto"
llava_model_args = {
    "multimodal": True,
    "attn_implementation": "sdpa",
}
tokenizer, model, image_processor, max_length = load_pretrained_model(pretrained, None, model_name, device_map=device_map, **llava_model_args)  # Add anything else you want to pass in llava_model_args
model.eval()
# url = "https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true"
# image = Image.open(requests.get(url, stream=True).raw)
image = Image.open("./examples/llava_v1_5_radar.jpg")
image_tensor = process_images([image], image_processor, model.config)
image_tensor = [_image.to(dtype=torch.float16, device=device) for _image in image_tensor]
conv_template = "qwen_1_5" # Make sure you use correct chat template for different models
question = DEFAULT_IMAGE_TOKEN + "\nWhat is shown in this image?"
conv = copy.deepcopy(conv_templates[conv_template])
conv.append_message(conv.roles[0], question)
conv.append_message(conv.roles[1], None)
prompt_question = conv.get_prompt()
input_ids = tokenizer_image_token(prompt_question, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(device)
image_sizes = [image.size]
cont = model.generate(
    input_ids,
    images=image_tensor,
    image_sizes=image_sizes,
    do_sample=False,
    temperature=0,
    max_new_tokens=4096,
)
text_outputs = tokenizer.batch_decode(cont, skip_special_tokens=True)
print(text_outputs)
from operator import attrgetter
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IGNORE_INDEX
from llava.conversation import conv_templates, SeparatorStyle
import torch
import cv2
import numpy as np
from PIL import Image
import requests
import copy
import warnings
from decord import VideoReader, cpu
import os
from pathlib import Path
current_dir = str(Path(__file__).resolve().parent)
warnings.filterwarnings("ignore")
# Load the OneVision model
# pretrained = "lmms-lab/llava-onevision-qwen2-7b-ov"
pretrained = os.path.join(current_dir, "ckpts", "llava-onevision-qwen2-7b-ov")
model_name = "llava_qwen"
device = "cuda"
device_map = "auto"
llava_model_args = {
    "multimodal": True,
}
tokenizer, model, image_processor, max_length = load_pretrained_model(pretrained, None, model_name, device_map=device_map, attn_implementation="sdpa", **llava_model_args)
model.eval()
# Function to extract frames from video
def load_video(video_path, max_frames_num):
    if isinstance(video_path, str):
        vr = VideoReader(video_path, ctx=cpu(0))
    else:
        vr = VideoReader(video_path[0], ctx=cpu(0))
    total_frame_num = len(vr)
    uniform_sampled_frames = np.linspace(0, total_frame_num - 1, max_frames_num, dtype=int)
    frame_idx = uniform_sampled_frames.tolist()
    spare_frames = vr.get_batch(frame_idx).asnumpy()
    return spare_frames  # (frames, height, width, channels)
# Load and process video
video_path = "./jobs.mp4"
video_frames = load_video(video_path, 16)
print(video_frames.shape) # (16, 1024, 576, 3)
image_tensors = []
frames = image_processor.preprocess(video_frames, return_tensors="pt")["pixel_values"].half().cuda()
image_tensors.append(frames)
# Prepare conversation input
conv_template = "qwen_1_5"
question = f"{DEFAULT_IMAGE_TOKEN}\nDescribe what's happening in this video."
conv = copy.deepcopy(conv_templates[conv_template])
conv.append_message(conv.roles[0], question)
conv.append_message(conv.roles[1], None)
prompt_question = conv.get_prompt()
input_ids = tokenizer_image_token(prompt_question, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(device)
image_sizes = [frame.size for frame in video_frames]
# Generate response
cont = model.generate(
    input_ids,
    images=image_tensors,
    image_sizes=image_sizes,
    do_sample=False,
    temperature=0,
    max_new_tokens=4096,
    modalities=["video"],
)
text_outputs = tokenizer.batch_decode(cont, skip_special_tokens=True)
print(text_outputs[0])
import json
import os
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
from multiprocessing import Pool
import functools
import argparse
def load_data(json_path):
    with open(json_path, "r") as f:
        return json.load(f)


def filter_data(data):
    filtered_data = [item for item in data if "image" in item]
    return filtered_data


def calculate_image_dimension(image_path, images_folder):
    full_path = os.path.join(images_folder, image_path)
    try:
        with Image.open(full_path) as img:
            width, height = img.size
            return width, height
    except Exception as e:
        print(f"Error opening {full_path}: {e}")
        return None, None
def calculate_image_dimensions_multiprocess(filtered_data, images_folder, num_processes=256):
    image_paths = []
    for item in filtered_data:
        if isinstance(item["image"], list):
            image_paths.extend(item["image"])
        else:
            image_paths.append(item["image"])
    with Pool(num_processes) as p:
        dimensions = list(
            tqdm(
                p.imap(functools.partial(calculate_image_dimension, images_folder=images_folder), image_paths),
                total=len(image_paths),
                desc="Calculating image dimensions",
            )
        )
    widths, heights = zip(*[dim for dim in dimensions if dim[0] is not None])
    return list(widths), list(heights)
def tokenize(text):
    # Whitespace split is used as a rough proxy for token count
    return text.split()


def calculate_tokenized_lengths(data):
    lengths = []
    for item in tqdm(data, desc="Tokenizing conversations"):
        for conversation in item["conversations"]:
            tokenized_value = tokenize(conversation["value"])
            lengths.append(len(tokenized_value))
    return lengths
def main():
    parser = argparse.ArgumentParser(description="Process data for LLaVA_Next project.")
    parser.add_argument(
        "--json_path",
        type=str,
        help="Path to the JSON file containing data.",
        default="/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/llava_ofa_DEMON-FULL.json",
    )
    parser.add_argument(
        "--images_folder",
        type=str,
        default="/mnt/bn/vl-research/data/llava_data",
        help="Path to the folder containing images.",
    )
    args = parser.parse_args()

    llava_instruct_name = os.path.basename(args.json_path).replace(".json", "")
    images_folder = args.images_folder

    data = load_data(args.json_path)
    filtered_data = filter_data(data)
    print(f"Total data items: {len(data)}, Filtered data items: {len(filtered_data)}")

    widths, heights = calculate_image_dimensions_multiprocess(filtered_data, images_folder)
    max_width, max_height = max(widths), max(heights)
    print(f"Max width: {max_width}, Max height: {max_height}")

    tokenized_lengths = calculate_tokenized_lengths(filtered_data)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(24, 12))

    # Plot 2D histogram of image widths and heights
    widths_bins = [min(widths), max(widths) + 1] if min(widths) == max(widths) else np.arange(min(widths), max(widths) + 100, 100)
    heights_bins = [min(heights), max(heights) + 1] if min(heights) == max(heights) else np.arange(min(heights), max(heights) + 100, 100)
    h, xedges, yedges, image = ax1.hist2d(widths, heights, bins=[widths_bins, heights_bins], cmap=plt.cm.jet, density=True)
    fig.colorbar(image, ax=ax1)
    ax1.set_xlabel("Width")
    ax1.set_ylabel("Height")
    ax1.set_title(
        f"dist_{llava_instruct_name}_2d_w_h\nMax width: {max(widths)}, Max height: {max(heights)}",
        fontsize=10,
    )

    # Plot histogram of tokenized conversation lengths
    hist, bin_edges = np.histogram(tokenized_lengths, bins=np.arange(0, max(tokenized_lengths) + 10, 10))
    bins = np.arange(0, max(tokenized_lengths) + 10, 10)
    ax2.bar(bin_edges[:-1], hist, width=7, edgecolor="black", log=True)

    # Display every nth label on the x-axis
    n = 8  # Adjust this value to control the number of labels displayed
    ticks = bins[::n]
    tick_labels = [int(tick) for tick in ticks]
    ax2.set_xticks(ticks)
    ax2.set_xticklabels(tick_labels, rotation=90, fontsize=8)
    ax2.set_xlim(min(bin_edges), max(bin_edges))
    ax2.set_xlabel("Tokenized Length")
    ax2.set_ylabel("Count (log scale)")
    ax2.set_title(f"dist_{llava_instruct_name}_tokenized_length", fontsize=8)

    plt.tight_layout()
    plt.savefig(f"./dist_{llava_instruct_name}_combined.png")


if __name__ == "__main__":
    main()
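A possible way to invoke the statistics script above; the script filename and the data paths shown here are placeholders, not names from this commit (the real defaults are the /mnt/bn/... paths in the argparse setup).

# python plot_llava_data_stats.py \
#     --json_path /path/to/llava_instruct.json \
#     --images_folder /path/to/image_root
#
# The output is a single figure, ./dist_<json_name>_combined.png, combining the
# width/height 2D histogram and the whitespace-token length histogram.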
import json
import os
from tqdm import tqdm
from multiprocessing import Pool, cpu_count
import yaml
class DataProcessor:
    def __init__(self, file_path, image_root, video_root):
        self.file_path = file_path
        self.image_root = image_root
        self.data = None
        self.video_root = video_root
        self.load_data()

    def load_data(self):
        if self.file_path.endswith(".json"):
            with open(self.file_path, "r") as f:
                self.data = json.load(f)
        elif self.file_path.endswith(".yaml"):
            with open(self.file_path, "r") as f:
                self.data = yaml.safe_load(f)
        elif self.file_path.endswith(".jsonl"):
            with open(self.file_path, "r") as f:
                self.data = [json.loads(line) for line in f.readlines()]
        else:
            raise ValueError("Unsupported file format")

    def load_json_data(self, json_path):
        if json_path.endswith(".jsonl"):
            cur_data_dict = []
            with open(json_path, "r") as json_file:
                for line in json_file:
                    cur_data_dict.append(json.loads(line.strip()))
            return cur_data_dict
        elif json_path.endswith(".json"):
            with open(json_path, "r") as f:
                return json.load(f)
        else:
            raise ValueError("Unsupported file format")
    def check_image_existence(self, data):
        if "image" in data:
            if isinstance(data["image"], list):
                images = data["image"]
            else:
                images = [data["image"]]
            for image in images:
                full_image_path = os.path.join(self.image_root, image)
                if not os.path.exists(full_image_path):
                    print(f"WARNING!!! {full_image_path} does not exist!!!")
        if "video" in data:
            full_video_path = os.path.join(self.video_root, data["video"])
            if not os.path.exists(full_video_path):
                print(f"WARNING!!! {full_video_path} does not exist!!!")
        # if data["conversations"][0]["value"].count("<image>") > 1:
        #     print(f"WARNING!!! {data['conversations'][0]['value']} has more than one <image> !!!")

    def check_item_structure(self, item):
        if not all(key in item for key in ["conversations"]):
            print(f"WARNING!!! Item {item.get('id', 'unknown')} is missing required fields!")
            return False
        conversations = item["conversations"]
        if not isinstance(conversations, list) or len(conversations) < 2 or len(conversations) % 2 != 0:
            print(f"WARNING!!! Item {item['id']} has an invalid conversations structure!")
            return False
        for i, conv in enumerate(conversations):
            if not all(key in conv for key in ["from", "value"]):
                print(f"WARNING!!! Item {item['id']} has an invalid conversation format!")
                return False
            # Turns must alternate human -> gpt
            expected_from = "human" if i % 2 == 0 else "gpt"
            if conv["from"] != expected_from:
                print(f"WARNING!!! Item {item['id']} has an incorrect conversation order!")
                return False
        return True

    def check_image_and_structure(self, item):
        if not self.check_item_structure(item):
            return
        # self.check_image_existence(item)
    def process_images(self):
        if isinstance(self.data, list):
            args = [item for item in self.data]
            with Pool(processes=cpu_count()) as pool:
                list(tqdm(pool.imap(self.check_image_and_structure, args), total=len(self.data)))
        elif isinstance(self.data, dict):
            for d in self.data["datasets"]:
                dd_json_path = d["json_path"]
                data = self.load_json_data(dd_json_path)
                args = [item for item in data]
                with Pool(processes=cpu_count()) as pool:
                    list(tqdm(pool.imap(self.check_image_and_structure, args), total=len(data), desc=f"Processing {dd_json_path}"))

    def count_items(self):
        if isinstance(self.data, list):  # Assuming JSON data loaded directly
            return len(self.data)
        elif isinstance(self.data, dict):  # Assuming YAML data loaded
            total_items_count = 0
            for d in self.data["datasets"]:
                dd_json_path = d["json_path"]
                data = self.load_json_data(dd_json_path)
                current_items_count = len(data)
                sampling_strategy = d["sampling_strategy"]
                try:
                    if sampling_strategy != "all":
                        percentage = float(sampling_strategy.split(":")[-1].replace("%", "")) / 100.0
                    else:
                        percentage = 1.0
                except Exception as e:
                    print(f"Error: {e}")
                    percentage = 1.0
                sampling_count = int(current_items_count * percentage)
                total_items_count += sampling_count
                print(f"{dd_json_path}: {sampling_count}")
            return total_items_count
    def stat_data(self):
        if isinstance(self.data, dict):
            cur_lens_list = []
            single_image_count = 0
            multiple_image_count = 0
            video_count = 0
            total_count = 0
            text_count = 0
            max_tokens_item = None
            max_tokens = 0
            for d in self.data["datasets"]:
                dd_json_path = d["json_path"]
                data = self.load_json_data(dd_json_path)
                sampling_strategy = d["sampling_strategy"]
                try:
                    if sampling_strategy != "all":
                        percentage = float(sampling_strategy.split(":")[-1].replace("%", "")) / 100.0
                    else:
                        percentage = 1.0
                except Exception as e:
                    print(f"Error parsing sampling strategy: {e}")
                    percentage = 1.0
                sampled_count = int(len(data) * percentage)
                print(f"{dd_json_path}: {sampled_count} (sampled from {len(data)})")
                for item in data[:sampled_count]:
                    conversations = item["conversations"]
                    cur_len = sum([len(conv["value"].split()) for conv in conversations])
                    cur_lens_list.append(cur_len)
                    if cur_len > max_tokens:
                        max_tokens = cur_len
                        max_tokens_item = item
                    total_count += 1
                    if "image" in item:
                        if isinstance(item["image"], list):
                            if len(item["image"]) > 1:
                                multiple_image_count += 1
                            else:
                                single_image_count += 1
                        else:
                            single_image_count += 1
                    elif "video" in item:
                        video_count += 1
                    else:
                        text_count += 1
            print(f"Max length: {max(cur_lens_list)}, Min length: {min(cur_lens_list)}, Average length: {sum(cur_lens_list) / len(cur_lens_list)}")
            print(f"Total items: {total_count}")
            print(f"Text items: {text_count} ({text_count/total_count*100:.2f}%)")
            print(f"Single image items: {single_image_count} ({single_image_count/total_count*100:.2f}%)")
            print(f"Multiple image items: {multiple_image_count} ({multiple_image_count/total_count*100:.2f}%)")
            print(f"Video items: {video_count} ({video_count/total_count*100:.2f}%)")
            print("\nItem with the largest number of tokens:")
            print(f"Token count: {max_tokens}")
            print("Item content:")
            print(json.dumps(max_tokens_item, indent=2))
    def filter_data(self):
        if isinstance(self.data, dict):
            for d in self.data["datasets"]:
                dd_json_path = d["json_path"]
                print(f"Processing {dd_json_path}")
                data = self.load_json_data(dd_json_path)
                filtered_data = []
                mismatch_data = []
                mismatch_flag = False
                for item in data:
                    try:
                        if "image" in item:
                            num_image = len(item["image"]) if isinstance(item["image"], list) else 1
                        else:
                            num_image = 0
                        if "video" in item:
                            num_video = len(item["video"]) if isinstance(item["video"], list) else 1
                        else:
                            num_video = 0
                        num_visuals = num_image + num_video
                        conv_text = ""
                        for conv in item["conversations"]:
                            conv_text += conv["value"]
                        num_img_token_appearance = conv_text.count("<image>")
                        if len(conv_text) == 0:
                            print(f"Conversation text is empty for {item}")
                        if num_img_token_appearance == num_visuals or (num_img_token_appearance < num_visuals and len(conv_text) > 0):
                            filtered_data.append(item)
                        elif num_img_token_appearance > num_visuals:
                            item["num_img_token_appearance"] = num_img_token_appearance
                            item["num_visuals"] = num_visuals
                            mismatch_data.append(item)
                            if not mismatch_flag:
                                print(f"Data mismatch for {item}")
                                mismatch_flag = True
                    except Exception as e:
                        print(f"Error: {e}")
                print()
                if mismatch_flag:
                    print(f"Data mismatch for {dd_json_path}")
                if len(filtered_data) < len(data):
                    # Build the output name once, so ".jsonl" paths are not suffixed twice
                    if dd_json_path.endswith(".jsonl"):
                        saving_dd_json_path = dd_json_path.replace(".jsonl", f"fltd_{len(filtered_data)}.json")
                    else:
                        saving_dd_json_path = dd_json_path.replace(".json", f"fltd_{len(filtered_data)}.json")
                    with open(saving_dd_json_path, "w") as f:
                        json.dump(filtered_data, f, indent=2)
                    print(f"Filtered data count: {len(filtered_data)}")
                else:
                    pass
    def stat_and_filter_data(self, threshold):
        if isinstance(self.data, dict):
            cur_lens_list = []
            single_image_count = 0
            multiple_image_count = 0
            video_count = 0
            total_count = 0
            text_count = 0
            for d in self.data["datasets"]:
                dd_json_path = d["json_path"]
                data = self.load_json_data(dd_json_path)
                sampling_strategy = d["sampling_strategy"]
                filtered_data = []
                try:
                    if sampling_strategy != "all":
                        percentage = float(sampling_strategy.split(":")[-1].replace("%", "")) / 100.0
                    else:
                        percentage = 1.0
                except Exception as e:
                    print(f"Error parsing sampling strategy: {e}")
                    percentage = 1.0
                sampled_count = int(len(data) * percentage)
                print(f"{dd_json_path}: {sampled_count} (sampled from {len(data)})")
                save_flag = False
                for item in data:
                    total_count += 1
                    conversations = item["conversations"]
                    filtered_conversations = []
                    current_token_count = 0
                    for i in range(0, len(conversations), 2):
                        if i + 1 < len(conversations):
                            human_conv = conversations[i]
                            gpt_conv = conversations[i + 1]
                            pair_tokens = len(human_conv["value"].split()) + len(gpt_conv["value"].split())
                            if current_token_count + pair_tokens <= threshold:
                                filtered_conversations.extend([human_conv, gpt_conv])
                                current_token_count += pair_tokens
                            else:
                                save_flag = True
                                break
                    if filtered_conversations:
                        item["conversations"] = filtered_conversations
                        cur_len = sum([len(conv["value"].split()) for conv in filtered_conversations])
                        cur_lens_list.append(cur_len)
                        filtered_data.append(item)
                        if "image" in item:
                            if isinstance(item["image"], list):
                                if len(item["image"]) > 1:
                                    multiple_image_count += 1
                                else:
                                    single_image_count += 1
                            else:
                                single_image_count += 1
                        elif "video" in item:
                            video_count += 1
                        else:
                            text_count += 1
                # Save filtered data for each dataset
                if filtered_data and save_flag:
                    if dd_json_path.endswith(".jsonl"):
                        output_file = dd_json_path.replace(".jsonl", f"_filtered_{threshold}tokens_{len(filtered_data)}.jsonl")
                        with open(output_file, "w") as f:
                            for item in filtered_data:
                                f.write(json.dumps(item) + "\n")
                    else:
                        output_file = dd_json_path.replace(".json", f"_filtered_{threshold}tokens_{len(filtered_data)}.json")
                        with open(output_file, "w") as f:
                            json.dump(filtered_data, f, indent=2)
                    print(f"Filtered data for {dd_json_path} saved to: {output_file}")
            print(f"Max length: {max(cur_lens_list)}, Min length: {min(cur_lens_list)}, Average length: {sum(cur_lens_list) / len(cur_lens_list)}")
            print(f"Total items: {total_count}")
            print(f"Text items: {text_count} ({text_count/total_count*100:.2f}%)")
            print(f"Single image items: {single_image_count} ({single_image_count/total_count*100:.2f}%)")
            print(f"Multiple image items: {multiple_image_count} ({multiple_image_count/total_count*100:.2f}%)")
            print(f"Video items: {video_count} ({video_count/total_count*100:.2f}%)")
def main(file_path, image_root, operation, video_root, threshold=None):
    processor = DataProcessor(file_path, image_root, video_root)
    if operation == "check":
        processor.process_images()
    elif operation == "count":
        total_items = processor.count_items()
        print(f"Total items: {total_items}")
    elif operation == "filter":
        processor.filter_data()
    elif operation == "stat":
        processor.stat_data()
    elif operation == "stat_and_filter":
        if threshold is None:
            raise ValueError("Threshold must be provided for stat_and_filter operation")
        processor.stat_and_filter_data(threshold)
    else:
        raise ValueError("Unsupported operation")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--file_path", type=str, default="/mnt/bn/vl-research/workspace/boli01/projects/LLaVA_Next/scripts/i18n/scale_llms/next_continual.yaml")
    parser.add_argument("--image_root", type=str, default="/mnt/bn/vl-research/data/llava_data")
    parser.add_argument("--video_root", type=str, default="/mnt/bn/vl-research/data/llava_video")
    parser.add_argument("--operation", type=str, default="filter")
    parser.add_argument("--threshold", type=int, default=None, help="Threshold for stat_and_filter operation")
    args = parser.parse_args()
    main(args.file_path, args.image_root, args.operation, args.video_root, args.threshold)
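For reference, a minimal sketch of the YAML layout that DataProcessor reads when file_path points at a .yaml file, inferred only from the keys used above (datasets, json_path, sampling_strategy); the dataset paths and the script/config names are hypothetical.

# datasets:
#   - json_path: /path/to/dataset_a.json
#     sampling_strategy: all
#   - json_path: /path/to/dataset_b.jsonl
#     sampling_strategy: "random:50%"   # any "<name>:<percent>%" value is parsed the same way
#
# Example invocation (script name is a placeholder):
# python check_llava_data.py --file_path /path/to/config.yaml \
#     --image_root /path/to/images --video_root /path/to/videos --operation count
#
# Equivalent programmatic use:
# processor = DataProcessor("/path/to/config.yaml", "/path/to/images", "/path/to/videos")
# print(processor.count_items())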