from ...smp import *
from .multiple_choice import extract_answer_from_item
from PIL import Image, ImageOps
import torchvision
import random
import numbers
import math
import torch
def get_dimension_rating(data_path):
data = load(data_path)
result_board = {}
for idx, item in data.iterrows():
if item['task_type'] not in result_board:
result_board[item['task_type']] = [0, 0]
result_board[item['task_type']][1] += 1
if item['score']:
result_board[item['task_type']][0] += 1
correct = 0
total = 0
for key, value in result_board.items():
correct += value[0]
total += value[1]
result_board[key].append(f'{value[0] / value[1] * 100 :.2f}%')
result_board['overall'] = [correct, total, f'{correct / total * 100 :.2f}%']
return result_board
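# A minimal usage sketch (the path below is hypothetical; the file must be
# loadable by the repo's `load` helper and carry 'task_type' and 'score' columns):
#
#     board = get_dimension_rating('MVBench_model_score.xlsx')
#     board['overall']  # -> [num_correct, num_total, 'xx.xx%']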
def check_ans(pred, gt):
flag = False
pred_list = pred.lower().split(' ')
pred_option, _ = pred_list[0], ' '.join(pred_list[1:])
gt_list = gt.lower().split(' ')
gt_option, gt_content = gt_list[0], ' '.join(gt_list[1:])
if gt_content[-1] == '.':
gt_content = gt_content[:-1]
if pred_option.replace('.', '') in gt_option:
flag = True
elif gt_option in pred_option:
flag = True
return flag
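# Illustration of the letter-prefix matching above (pure string logic):
#
#     check_ans('A. the cat', 'A. The cat.')  # -> True  ('a' is contained in 'a.')
#     check_ans('B', 'A. The cat.')           # -> False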
def check_ans_with_model(pred, gt, model, item, dataset_name='MVBench'):
flag = False
pred_list = pred.lower().split(' ')
pred_option, _ = pred_list[0], ' '.join(pred_list[1:])
gt_list = gt.lower().split(' ')
gt_option, gt_content = gt_list[0], ' '.join(gt_list[1:])
if gt_content[-1] == '.':
gt_content = gt_content[:-1]
if pred_option.replace('.', '') in gt_option:
flag = True
elif gt_option in pred_option:
flag = True
elif extract_answer_from_item(model, item, dataset_name)['opt'] == item['answer']:
flag = True
return flag
def check_ans_advanced(pred, gt):
number_table = {
0: 'zero',
1: 'one',
2: 'two',
3: 'three',
4: 'four',
5: 'five',
6: 'six',
7: 'seven',
8: 'eight',
9: 'nine',
}
flag = False
pred_list = pred.lower().split(' ')
pred_option, _ = pred_list[0], ' '.join(pred_list[1:])
gt_list = gt.lower().split(' ')
gt_option, gt_content = gt_list[0], ' '.join(gt_list[1:])
if gt_content[-1] == '.':
gt_content = gt_content[:-1]
    try:
        # Map a bare digit ground truth (e.g. '3') to its spelled-out form ('three')
        gt_content = number_table[int(gt_content.strip('. \n'))]
    except (ValueError, KeyError):
        pass
if pred_option.replace('.', '') in gt_option:
flag = True
elif gt_option in pred_option:
flag = True
elif gt_content.lower().strip('. \n') in pred.lower().strip('. \n'):
flag = True
return flag
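# Illustration of the digit-to-word fallback above: for gt = 'C. 3.' the content
# '3' is mapped to 'three', so a free-form prediction is accepted by the final
# substring check:
#
#     check_ans_advanced('there are three dogs', 'C. 3.')  # -> True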
class GroupRandomCrop(object):
def __init__(self, size):
if isinstance(size, numbers.Number):
self.size = (int(size), int(size))
else:
self.size = size
def __call__(self, img_group):
w, h = img_group[0].size
th, tw = self.size
out_images = list()
x1 = random.randint(0, w - tw)
y1 = random.randint(0, h - th)
for img in img_group:
assert (img.size[0] == w and img.size[1] == h)
if w == tw and h == th:
out_images.append(img)
else:
out_images.append(img.crop((x1, y1, x1 + tw, y1 + th)))
return out_images
class MultiGroupRandomCrop(object):
def __init__(self, size, groups=1):
if isinstance(size, numbers.Number):
self.size = (int(size), int(size))
else:
self.size = size
self.groups = groups
def __call__(self, img_group):
w, h = img_group[0].size
th, tw = self.size
out_images = list()
for i in range(self.groups):
x1 = random.randint(0, w - tw)
y1 = random.randint(0, h - th)
for img in img_group:
assert (img.size[0] == w and img.size[1] == h)
if w == tw and h == th:
out_images.append(img)
else:
out_images.append(img.crop((x1, y1, x1 + tw, y1 + th)))
return out_images
class GroupCenterCrop(object):
def __init__(self, size):
self.worker = torchvision.transforms.CenterCrop(size)
def __call__(self, img_group):
return [self.worker(img) for img in img_group]
class GroupRandomHorizontalFlip(object):
"""Randomly horizontally flips the given PIL.Image with a probability of 0.5
"""
def __init__(self, is_flow=False):
self.is_flow = is_flow
def __call__(self, img_group, is_flow=False):
v = random.random()
if v < 0.5:
ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group]
if self.is_flow:
for i in range(0, len(ret), 2):
# invert flow pixel values when flipping
ret[i] = ImageOps.invert(ret[i])
return ret
else:
return img_group
class GroupNormalize(object):
def __init__(self, mean, std):
self.mean = mean
self.std = std
def __call__(self, tensor):
rep_mean = self.mean * (tensor.size()[0] // len(self.mean))
rep_std = self.std * (tensor.size()[0] // len(self.std))
# TODO: make efficient
for t, m, s in zip(tensor, rep_mean, rep_std):
t.sub_(m).div_(s)
return tensor
class GroupScale(object):
""" Rescales the input PIL.Image to the given 'size'.
'size' will be the size of the smaller edge.
For example, if height > width, then image will be
rescaled to (size * height / width, size)
size: size of the smaller edge
interpolation: Default: PIL.Image.BILINEAR
"""
def __init__(self, size, interpolation=Image.BILINEAR):
self.worker = torchvision.transforms.Resize(size, interpolation)
def __call__(self, img_group):
return [self.worker(img) for img in img_group]
class GroupOverSample(object):
def __init__(self, crop_size, scale_size=None, flip=True):
self.crop_size = crop_size if not isinstance(
crop_size, int) else (crop_size, crop_size)
if scale_size is not None:
self.scale_worker = GroupScale(scale_size)
else:
self.scale_worker = None
self.flip = flip
def __call__(self, img_group):
if self.scale_worker is not None:
img_group = self.scale_worker(img_group)
image_w, image_h = img_group[0].size
crop_w, crop_h = self.crop_size
offsets = GroupMultiScaleCrop.fill_fix_offset(
False, image_w, image_h, crop_w, crop_h)
oversample_group = list()
for o_w, o_h in offsets:
normal_group = list()
flip_group = list()
for i, img in enumerate(img_group):
crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h))
normal_group.append(crop)
flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT)
if img.mode == 'L' and i % 2 == 0:
flip_group.append(ImageOps.invert(flip_crop))
else:
flip_group.append(flip_crop)
oversample_group.extend(normal_group)
if self.flip:
oversample_group.extend(flip_group)
return oversample_group
class GroupFullResSample(object):
def __init__(self, crop_size, scale_size=None, flip=True):
self.crop_size = crop_size if not isinstance(
crop_size, int) else (crop_size, crop_size)
if scale_size is not None:
self.scale_worker = GroupScale(scale_size)
else:
self.scale_worker = None
self.flip = flip
def __call__(self, img_group):
if self.scale_worker is not None:
img_group = self.scale_worker(img_group)
image_w, image_h = img_group[0].size
crop_w, crop_h = self.crop_size
w_step = (image_w - crop_w) // 4
h_step = (image_h - crop_h) // 4
offsets = list()
offsets.append((0 * w_step, 2 * h_step)) # left
offsets.append((4 * w_step, 2 * h_step)) # right
offsets.append((2 * w_step, 2 * h_step)) # center
oversample_group = list()
for o_w, o_h in offsets:
normal_group = list()
flip_group = list()
for i, img in enumerate(img_group):
crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h))
normal_group.append(crop)
if self.flip:
flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT)
if img.mode == 'L' and i % 2 == 0:
flip_group.append(ImageOps.invert(flip_crop))
else:
flip_group.append(flip_crop)
oversample_group.extend(normal_group)
oversample_group.extend(flip_group)
return oversample_group
class GroupMultiScaleCrop(object):
def __init__(self, input_size, scales=None, max_distort=1,
fix_crop=True, more_fix_crop=True):
self.scales = scales if scales is not None else [1, .875, .75, .66]
self.max_distort = max_distort
self.fix_crop = fix_crop
self.more_fix_crop = more_fix_crop
self.input_size = input_size if not isinstance(input_size, int) else [
input_size, input_size]
self.interpolation = Image.BILINEAR
def __call__(self, img_group):
im_size = img_group[0].size
crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size)
crop_img_group = [
img.crop(
(offset_w,
offset_h,
offset_w + crop_w,
offset_h + crop_h)) for img in img_group]
ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation)
for img in crop_img_group]
return ret_img_group
def _sample_crop_size(self, im_size):
image_w, image_h = im_size[0], im_size[1]
# find a crop size
base_size = min(image_w, image_h)
crop_sizes = [int(base_size * x) for x in self.scales]
crop_h = [
self.input_size[1] if abs(
x - self.input_size[1]) < 3 else x for x in crop_sizes]
crop_w = [
self.input_size[0] if abs(
x - self.input_size[0]) < 3 else x for x in crop_sizes]
pairs = []
for i, h in enumerate(crop_h):
for j, w in enumerate(crop_w):
if abs(i - j) <= self.max_distort:
pairs.append((w, h))
crop_pair = random.choice(pairs)
if not self.fix_crop:
w_offset = random.randint(0, image_w - crop_pair[0])
h_offset = random.randint(0, image_h - crop_pair[1])
else:
w_offset, h_offset = self._sample_fix_offset(
image_w, image_h, crop_pair[0], crop_pair[1])
return crop_pair[0], crop_pair[1], w_offset, h_offset
def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h):
offsets = self.fill_fix_offset(
self.more_fix_crop, image_w, image_h, crop_w, crop_h)
return random.choice(offsets)
@staticmethod
def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h):
w_step = (image_w - crop_w) // 4
h_step = (image_h - crop_h) // 4
ret = list()
ret.append((0, 0)) # upper left
ret.append((4 * w_step, 0)) # upper right
ret.append((0, 4 * h_step)) # lower left
ret.append((4 * w_step, 4 * h_step)) # lower right
ret.append((2 * w_step, 2 * h_step)) # center
if more_fix_crop:
ret.append((0, 2 * h_step)) # center left
ret.append((4 * w_step, 2 * h_step)) # center right
ret.append((2 * w_step, 4 * h_step)) # lower center
ret.append((2 * w_step, 0 * h_step)) # upper center
ret.append((1 * w_step, 1 * h_step)) # upper left quarter
ret.append((3 * w_step, 1 * h_step)) # upper right quarter
ret.append((1 * w_step, 3 * h_step)) # lower left quarter
            ret.append((3 * w_step, 3 * h_step))  # lower right quarter
return ret
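# Sketch of the offset grid above: for a 256x256 image and a 224x224 crop,
# w_step = h_step = (256 - 224) // 4 = 8, so the five base offsets are the
# four corners plus the center: (0, 0), (32, 0), (0, 32), (32, 32), (16, 16).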
class GroupRandomSizedCrop(object):
"""Random crop the given PIL.Image to a random size of (0.08 to 1.0) of the original size
and and a random aspect ratio of 3/4 to 4/3 of the original aspect ratio
This is popularly used to train the Inception networks
size: size of the smaller edge
interpolation: Default: PIL.Image.BILINEAR
"""
def __init__(self, size, interpolation=Image.BILINEAR):
self.size = size
self.interpolation = interpolation
def __call__(self, img_group):
for attempt in range(10):
area = img_group[0].size[0] * img_group[0].size[1]
target_area = random.uniform(0.08, 1.0) * area
aspect_ratio = random.uniform(3. / 4, 4. / 3)
w = int(round(math.sqrt(target_area * aspect_ratio)))
h = int(round(math.sqrt(target_area / aspect_ratio)))
if random.random() < 0.5:
w, h = h, w
if w <= img_group[0].size[0] and h <= img_group[0].size[1]:
x1 = random.randint(0, img_group[0].size[0] - w)
y1 = random.randint(0, img_group[0].size[1] - h)
found = True
break
else:
found = False
x1 = 0
y1 = 0
if found:
out_group = list()
for img in img_group:
img = img.crop((x1, y1, x1 + w, y1 + h))
assert (img.size == (w, h))
out_group.append(
img.resize(
(self.size, self.size), self.interpolation))
return out_group
else:
# Fallback
scale = GroupScale(self.size, interpolation=self.interpolation)
crop = GroupRandomCrop(self.size)
return crop(scale(img_group))
class ConvertDataFormat(object):
def __init__(self, model_type):
self.model_type = model_type
def __call__(self, images):
if self.model_type == '2D':
return images
tc, h, w = images.size()
t = tc // 3
images = images.view(t, 3, h, w)
images = images.permute(1, 0, 2, 3)
return images
class Stack(object):
def __init__(self, roll=False):
self.roll = roll
def __call__(self, img_group):
if img_group[0].mode == 'L':
return np.concatenate([np.expand_dims(x, 2)
for x in img_group], axis=2)
elif img_group[0].mode == 'RGB':
if self.roll:
return np.concatenate([np.array(x)[:, :, ::-1]
for x in img_group], axis=2)
else:
# print(np.concatenate(img_group, axis=2).shape)
# print(img_group[0].shape)
return np.concatenate(img_group, axis=2)
class ToTorchFormatTensor(object):
""" Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255]
to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """
def __init__(self, div=True):
self.div = div
def __call__(self, pic):
if isinstance(pic, np.ndarray):
# handle numpy array
img = torch.from_numpy(pic).permute(2, 0, 1).contiguous()
else:
# handle PIL Image
img = torch.ByteTensor(
torch.ByteStorage.from_buffer(
pic.tobytes()))
img = img.view(pic.size[1], pic.size[0], len(pic.mode))
# put it from HWC to CHW format
# yikes, this transpose takes 80% of the loading time/CPU
img = img.transpose(0, 1).transpose(0, 2).contiguous()
return img.float().div(255) if self.div else img.float()
class IdentityTransform(object):
def __call__(self, data):
return data
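# A minimal sketch composing the group transforms above into a TSN-style clip
# preprocessing pipeline. The mean/std values are the usual ImageNet statistics,
# an assumption rather than something mandated by this file:
#
#     import torchvision
#     from PIL import Image
#
#     transform = torchvision.transforms.Compose([
#         GroupScale(256),
#         GroupCenterCrop(224),
#         Stack(roll=False),
#         ToTorchFormatTensor(div=True),
#         GroupNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
#     ])
#     frames = [Image.new('RGB', (320, 240)) for _ in range(8)]
#     clip = transform(frames)  # tensor of shape (8 * 3, 224, 224)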
from ...smp import *
def OCRBench_eval(eval_file):
OCRBench_score = {
'Regular Text Recognition': 0,
'Irregular Text Recognition': 0,
'Artistic Text Recognition': 0,
'Handwriting Recognition': 0,
'Digit String Recognition': 0,
'Non-Semantic Text Recognition': 0,
'Scene Text-centric VQA': 0,
'Doc-oriented VQA': 0,
'Key Information Extraction': 0,
'Handwritten Mathematical Expression Recognition': 0
}
logger = get_logger('Evaluation')
data = load(eval_file)
lt = len(data)
lines = [data.iloc[i] for i in range(lt)]
for i in tqdm(range(len(lines))):
line = lines[i]
predict = str(line['prediction'])
answers = eval(line['answer'])
category = line['category']
if category == 'Handwritten Mathematical Expression Recognition':
for j in range(len(answers)):
answer = answers[j].strip().replace('\n', ' ').replace(' ', '')
predict = predict.strip().replace('\n', ' ').replace(' ', '')
if answer in predict:
OCRBench_score[category] += 1
break
else:
for j in range(len(answers)):
answer = answers[j].lower().strip().replace('\n', ' ')
predict = predict.lower().strip().replace('\n', ' ')
if answer in predict:
OCRBench_score[category] += 1
break
final_score_dict = {}
final_score_dict['Text Recognition'] = (
OCRBench_score['Regular Text Recognition'] + OCRBench_score['Irregular Text Recognition']
+ OCRBench_score['Artistic Text Recognition'] + OCRBench_score['Handwriting Recognition']
+ OCRBench_score['Digit String Recognition'] + OCRBench_score['Non-Semantic Text Recognition']
)
final_score_dict['Scene Text-centric VQA'] = OCRBench_score['Scene Text-centric VQA']
final_score_dict['Doc-oriented VQA'] = OCRBench_score['Doc-oriented VQA']
final_score_dict['Key Information Extraction'] = OCRBench_score['Key Information Extraction']
final_score_dict['Handwritten Mathematical Expression Recognition'] = \
OCRBench_score['Handwritten Mathematical Expression Recognition']
final_score_dict['Final Score'] = (
final_score_dict['Text Recognition'] + final_score_dict['Scene Text-centric VQA']
+ final_score_dict['Doc-oriented VQA'] + final_score_dict['Key Information Extraction']
+ final_score_dict['Handwritten Mathematical Expression Recognition']
)
final_score_dict['Final Score Norm'] = float(final_score_dict['Final Score']) / 10
score_pth = eval_file.replace('.xlsx', '_score.json')
dump(final_score_dict, score_pth)
logger.info(f'OCRBench_eval successfully finished evaluating {eval_file}, results saved in {score_pth}')
logger.info('Score: ')
for key, value in final_score_dict.items():
logger.info('{}:{}'.format(key, value))
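# A minimal usage sketch (hypothetical path; the .xlsx must carry
# 'prediction', 'answer' and 'category' columns):
#
#     OCRBench_eval('OCRBench_model.xlsx')
#     # writes OCRBench_model_score.json next to the input file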
"""
Copied from https://github.com/allenai/allennlp-semparse
Modified from https://github.com/naver-ai/tablevqabench
"""
import re
import unicodedata
import time
from abc import ABCMeta, abstractmethod
from math import isinf, isnan
# Vision Prompts
VWTQ_PROMPT = (
'You are asked to answer questions asked on an image.\n'
'You should answer the question with a single word.\n'
'Example: \n'
'Question: what was the only year mr. wu competed in the olympic games?\n'
'Answer: 2004\n'
'Question: which township in pope county, arkansas has the least amount of water area?\n'
'Answer: Freeman\n'
'If you have multiple answers, please separate them with || marks. Example: Apple||Banana||Tomato\n\n'
'Question: {question}\n'
'Answer:'
)
VTABFACT_PROMPT = (
'You are asked to answer whether the statement is True or False based on given image\n'
'You should only answer True or False.\n'
'Example: \n'
'Statement: the milwaukee buck win 6 game in the 2010 - 11 season\n'
'Answer: True\n'
'Statement: only the top team score above the average of 8.8\n'
'Answer: False\n\n'
'Statement: {question}\n'
'Answer:'
)
FINTABNETQA_PROMPT = (
'You are asked to answer questions asked on a image.\n'
'You should answer the question within a single word or few words.\n'
'If units can be known, the answer should include units such as $, %, million and etc.\n'
'Example: \n'
'Question: What were the total financing originations for the fiscal year ended October 31, 2004?\n'
'Answer: $3,852 million\n'
'Question: What is the time period represented in the table?\n'
'Answer: October 31\n'
'Question: What was the percentage of net sales for selling, general and administrative expenses in 2006?\n'
'Answer: 34.2%\n'
'Question: {question}\n'
'Answer:'
)
def evaluate_tabfact(data, score_keys):
num_examples = 0
num_correct = 0
manual_check = 0
start_time = time.time()
for instance in data:
if instance['prediction'] is None:
instance['prediction'] = 'none'
pred = instance['prediction'].lower()
gt = instance['answer']
num_examples += 1
if 'true' in pred and 'false' in pred:
manual_check += 1
score = None
elif 'true' in pred and gt == '1':
num_correct += 1
score = 1
elif 'false' in pred and gt == '0':
num_correct += 1
score = 1
else:
score = 0
instance['scores'] = {score_keys[0]: score}
if manual_check > 0:
            print(f'number of samples that could not be parsed automatically: {manual_check}')
end_time = time.time()
elapsed_time = end_time - start_time
Accuracy = round((num_correct + 1e-9) / (num_examples + 1e-9), 8) * 100
meta = {
'evaluators': 'correctness',
'score_info': [score_keys[0]],
'evaluated_time': elapsed_time,
'total_num_sample': len(data),
'average_scores': [Accuracy],
}
return meta
def evaluate_wtq(data, score_keys):
num_examples = 0
num_correct = 0
start_time = time.time()
for instance in data:
pred = instance['prediction'].replace('||', '|')
gt = instance['answer']
original_strings = tsv_unescape_list(gt)
target_values = to_value_list(original_strings)
predicted_strings = tsv_unescape_list(pred)
predicted_values = to_value_list(predicted_strings)
correct = check_denotation(target_values, predicted_values)
num_examples += 1
score = 0
if correct:
num_correct += 1
score = 1
instance['scores'] = {score_keys[0]: score}
end_time = time.time()
elapsed_time = end_time - start_time
Accuracy = round((num_correct + 1e-9) / (num_examples + 1e-9), 8) * 100
meta = {
'evaluators': 'correctness',
'score_info': [score_keys[0]],
'evaluated_time': elapsed_time,
'total_num_sample': len(data),
'average_scores': [Accuracy],
}
return meta
def evaluate_fintabnet(data, score_keys):
num_examples = 0
num_correct, _num_correct = 0, 0
start_time = time.time()
for instance in data:
pred, preds = fintabnet_normalize(instance['prediction'])
gt, gts = fintabnet_normalize(instance['answer'])
correct = 1 if gt == pred else 0
_correct = any(_pred == _gt for _pred in preds for _gt in gts)
num_examples += 1
score, _score = 0, 0
if correct:
num_correct += 1
score = 1
if _correct:
_num_correct += 1
_score = 1
instance['scores'] = {score_keys[0]: _score, 'exact_score': score}
end_time = time.time()
elapsed_time = end_time - start_time
Accuracy = round((num_correct + 1e-9) / (num_examples + 1e-9), 8) * 100
_Accuracy = round((_num_correct + 1e-9) / (num_examples + 1e-9), 8) * 100
meta = {
'evaluators': 'correctness',
'score_info': ['relieved_accuracy', score_keys[0]],
'evaluated_time': elapsed_time,
'total_num_sample': len(data),
'average_scores': [_Accuracy, Accuracy],
}
return meta
def fintabnet_normalize(s):
s = normalize(s)
remove_words = [
'dollar', 'gallons', 'square feet', 'shares', 'mbtu',
'mbpd', 'mbbls', 'mmbtu', 'unit', 'gwh', 'year', 'mmcf', 'mile', 'mboe'
]
# Data specific filtering using regular expressions
# Remove special characters like $, (, and )
s = re.sub(r'[\$\(\),]', '', s)
# Replace "dollar" with empty string if it's not part of another word
pattern = r'\b(' + '|'.join(remove_words) + r')s?\b'
s = re.sub(pattern, '', s, flags=re.IGNORECASE)
# Unit conversion dictionary with regex patterns for flexibility
unit_conversion = {
r' \bthousand\b': 'e3',
r' \bmillion\b': 'e6',
r' \bbillion\b': 'e9',
r'\bthousand\b': 'e3',
r'\bmillion\b': 'e6',
r'\bbillion\b': 'e9',
r' ?%': 'e-2',
}
# Convert percentages to their decimal representation.
# Applying this after unit_conversion prevents "percent" from being processed
# in cases like "million %", which would be incorrect.
# s = re.sub(r' ?%', 'e-2', s)
# s_percent = re.sub(r' ?%', '', s_percent)
s_unit_free = s
# Iterate over unit_conversion and apply transformations
for pattern, value in unit_conversion.items():
s = re.sub(pattern, value, s)
s_unit_free = re.sub(pattern, '', s_unit_free)
# Attempt to convert to float
try:
return float(s), [float(s), float(s_unit_free)]
except ValueError:
# Return the original string and the error for debugging purposes
return s, [s, s_unit_free]
def normalize(x):
if not isinstance(x, str):
x = x.decode('utf8', errors='ignore')
# Remove diacritics
x = ''.join(
c for c in unicodedata.normalize('NFKD', x) if unicodedata.category(c) != 'Mn'
)
# Normalize quotes and dashes
x = re.sub(r'[‘’´`]', "'", x)
x = re.sub(r'[“”]', '"', x)
x = re.sub(r'[‐‑‒–—−]', '-', x)
while True:
old_x = x
# Remove citations
x = re.sub(r'((?<!^)\[[^\]]*\]|\[\d+\]|[•♦†‡*#+])*$', '', x.strip())
# Remove details in parenthesis
x = re.sub(r'(?<!^)( \([^)]*\))*$', '', x.strip())
# Remove outermost quotation mark
x = re.sub(r'^"([^"]*)"$', r'\1', x.strip())
if x == old_x:
break
# Remove final '.'
if x and x[-1] == '.':
x = x[:-1]
# Collapse whitespaces and convert to lower case
x = re.sub(r'\s+', ' ', x, flags=re.U).lower().strip()
return x
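# Examples of the normalization above (pure function, safe to run):
#
#     normalize('"Hello"')      # -> 'hello'  (outermost quotes stripped, lowercased)
#     normalize('Café (city)')  # -> 'cafe'   (diacritics and trailing parenthetical dropped)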
# Value Types
class Value(object):
__metaclass__ = ABCMeta
# Should be populated with the normalized string
_normalized = None
@abstractmethod
def match(self, other):
"""Return True if the value matches the other value.
Args:
other (Value)
Returns:
a boolean
"""
pass
@property
def normalized(self):
return self._normalized
class StringValue(Value):
def __init__(self, content):
assert isinstance(content, str)
self._normalized = normalize(content)
self._hash = hash(self._normalized)
def __eq__(self, other):
return isinstance(other, StringValue) and self.normalized == other.normalized
def __hash__(self):
return self._hash
def __str__(self):
return 'S' + str([self.normalized])
def __repr__(self):
return self.__str__()
def match(self, other):
assert isinstance(other, Value)
return self.normalized == other.normalized
class NumberValue(Value):
def __init__(self, amount, original_string=None):
assert isinstance(amount, (int, float))
if abs(amount - round(amount)) < 1e-6:
self._amount = int(amount)
else:
self._amount = float(amount)
if not original_string:
self._normalized = str(self._amount)
else:
self._normalized = normalize(original_string)
self._hash = hash(self._amount)
@property
def amount(self):
return self._amount
def __eq__(self, other):
return isinstance(other, NumberValue) and self.amount == other.amount
def __hash__(self):
return self._hash
def __str__(self):
return 'N({})'.format(self.amount) + str([self.normalized])
def __repr__(self):
return self.__str__()
def match(self, other):
assert isinstance(other, Value)
if self.normalized == other.normalized:
return True
if isinstance(other, NumberValue):
return abs(self.amount - other.amount) < 1e-6
return False
@staticmethod
def parse(text):
"""Try to parse into a number.
Return:
the number (int or float) if successful; otherwise None.
"""
try:
return int(text)
except ValueError:
try:
amount = float(text)
assert not isnan(amount) and not isinf(amount)
return amount
except ValueError:
return None
class DateValue(Value):
def __init__(self, year, month, day, original_string=None):
"""Create a new DateValue. Placeholders are marked as -1."""
assert isinstance(year, int)
assert isinstance(month, int) and (month == -1 or 1 <= month <= 12)
assert isinstance(day, int) and (day == -1 or 1 <= day <= 31)
assert not (year == month == day == -1)
self._year = year
self._month = month
self._day = day
if not original_string:
self._normalized = '{}-{}-{}'.format(
year if year != -1 else 'xx',
month if month != -1 else 'xx',
                day if day != -1 else 'xx',
)
else:
self._normalized = normalize(original_string)
self._hash = hash((self._year, self._month, self._day))
@property
def ymd(self):
return (self._year, self._month, self._day)
def __eq__(self, other):
return isinstance(other, DateValue) and self.ymd == other.ymd
def __hash__(self):
return self._hash
def __str__(self):
return ('D(%d,%d,%d)' % (self._year, self._month, self._day)) + str(
[self._normalized]
)
__repr__ = __str__
def match(self, other):
assert isinstance(other, Value)
if self.normalized == other.normalized:
return True
if isinstance(other, DateValue):
return self.ymd == other.ymd
return False
@staticmethod
def parse(text):
"""Try to parse into a date.
Return:
tuple (year, month, date) if successful; otherwise None.
"""
try:
ymd = text.lower().split('-')
assert len(ymd) == 3
year = -1 if ymd[0] in ('xx', 'xxxx') else int(ymd[0])
month = -1 if ymd[1] == 'xx' else int(ymd[1])
day = -1 if ymd[2] == 'xx' else int(ymd[2])
assert not (year == month == day == -1)
assert month == -1 or 1 <= month <= 12
assert day == -1 or 1 <= day <= 31
return (year, month, day)
        except Exception:
return None
# Value Instantiation
def to_value(original_string, corenlp_value=None):
"""Convert the string to Value object.
Args:
original_string (basestring): Original string
corenlp_value (basestring): Optional value returned from CoreNLP
Returns:
Value
"""
if isinstance(original_string, Value):
# Already a Value
return original_string
if not corenlp_value:
corenlp_value = original_string
# Number?
amount = NumberValue.parse(corenlp_value)
if amount is not None:
return NumberValue(amount, original_string)
# Date?
ymd = DateValue.parse(corenlp_value)
if ymd is not None:
if ymd[1] == ymd[2] == -1:
return NumberValue(ymd[0], original_string)
else:
return DateValue(ymd[0], ymd[1], ymd[2], original_string)
# String.
return StringValue(original_string)
def to_value_list(original_strings, corenlp_values=None):
"""Convert a list of strings to a list of Values
Args:
original_strings (list[basestring])
corenlp_values (list[basestring or None])
Returns:
list[Value]
"""
assert isinstance(original_strings, (list, tuple, set))
if corenlp_values is not None:
assert isinstance(corenlp_values, (list, tuple, set))
assert len(original_strings) == len(corenlp_values)
return list(
set(to_value(x, y) for (x, y) in zip(original_strings, corenlp_values))
)
else:
return list(set(to_value(x) for x in original_strings))
# Check the Predicted Denotations
def check_denotation(target_values, predicted_values):
"""Return True if the predicted denotation is correct.
Args:
target_values (list[Value])
predicted_values (list[Value])
Returns:
bool
"""
# Check size
if len(target_values) != len(predicted_values):
return False
# Check items
for target in target_values:
if not any(target.match(pred) for pred in predicted_values):
return False
return True
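# Example of the denotation check above (pure functions, safe to run):
#
#     check_denotation(to_value_list(['2004']), to_value_list(['2004.0']))          # -> True  (equal amounts)
#     check_denotation(to_value_list(['2004', 'paris']), to_value_list(['2004']))   # -> False (size mismatch)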
# Batch Mode
def tsv_unescape(x):
"""Unescape strings in the TSV file.
Escaped characters include:
        newline (0x0A) -> backslash + n
vertical bar (0x7C) -> backslash + p
backslash (0x5C) -> backslash + backslash
Args:
x (str or unicode)
Returns:
a unicode
"""
return x.replace(r'\n', '\n').replace(r'\p', '|').replace('\\\\', '\\')
def tsv_unescape_list(x):
"""Unescape a list in the TSV file.
    List items are joined with vertical bars (0x7C)
Args:
x (str or unicode)
Returns:
a list of unicodes
"""
return [tsv_unescape(y) for y in x.split('|')]
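# Examples of the TSV unescaping above:
#
#     tsv_unescape_list('2004|2008')  # -> ['2004', '2008']
#     tsv_unescape(r'A\pB')           # -> 'A|B' (escaped vertical bar inside one item)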
from ...smp import *
from .multiple_choice import extract_answer_from_item
import numpy as np
import re
FAIL_MSG = 'Failed to obtain answer via API.'
DURATIONS = [
'short',
'medium',
'long',
]
DOMAINS = [
'Knowledge',
'Film & Television',
'Sports Competition',
'Artistic Performance',
'Life Record',
'Multilingual'
]
SUB_CATEGORIES = [
'Humanity & History',
'Literature & Art',
'Biology & Medicine',
'Finance & Commerce',
'Astronomy',
'Geography',
'Law',
'Life Tip',
'Technology',
'Animation',
'Movie & TV Show',
'Documentary',
'News Report',
'Esports',
'Basketball',
'Football',
'Athletics',
'Other Sports',
'Stage Play',
'Magic Show',
'Variety Show',
'Acrobatics',
'Handicraft',
'Food',
'Fashion',
'Daily Life',
'Travel',
'Pet & Animal',
'Exercise',
'Multilingual'
]
TASK_CATEGORIES = [
'Temporal Perception',
'Spatial Perception',
'Attribute Perception',
'Action Recognition',
'Object Recognition',
'OCR Problems',
'Counting Problem',
'Temporal Reasoning',
'Spatial Reasoning',
'Action Reasoning',
'Object Reasoning',
'Information Synopsis',
]
def get_dimension_rating(data_path):
data = load(data_path)
duration_rating = {k: {} for k in DURATIONS}
for duration in DURATIONS + ['overall']:
duration_rating[duration] = {
'overall': '',
'domain': {k: [] for k in DOMAINS},
'sub_category': {k: [] for k in SUB_CATEGORIES},
'task_type': {k: [] for k in TASK_CATEGORIES}
}
for i in range(len(data)):
domain = data.iloc[i]['domain']
sub_ctg = data.iloc[i]['sub_category']
task_ctg = data.iloc[i]['task_type']
duration = data.iloc[i]['duration']
duration_rating[duration]['domain'][domain].append(data.iloc[i]['score'])
duration_rating[duration]['sub_category'][sub_ctg].append(data.iloc[i]['score'])
duration_rating[duration]['task_type'][task_ctg].append(data.iloc[i]['score'])
duration_rating['overall']['domain'][domain].append(data.iloc[i]['score'])
duration_rating['overall']['sub_category'][sub_ctg].append(data.iloc[i]['score'])
duration_rating['overall']['task_type'][task_ctg].append(data.iloc[i]['score'])
for duration in DURATIONS + ['overall']:
overall_res_dur = f'{np.mean([x for x in sum(duration_rating[duration]["domain"].values(), []) if x >= 0]):.3f}'
duration_rating[duration]['overall'] = overall_res_dur
for domain in DOMAINS:
domain_res_dur = f'{np.mean([x for x in duration_rating[duration]["domain"][domain] if x >= 0]):.3f}'
duration_rating[duration]['domain'][domain] = domain_res_dur
for sub_ctg in SUB_CATEGORIES:
sub_res_dur = f'{np.mean([x for x in duration_rating[duration]["sub_category"][sub_ctg] if x >= 0]):.3f}'
duration_rating[duration]['sub_category'][sub_ctg] = sub_res_dur
for task_ctg in TASK_CATEGORIES:
task_res_dur = f'{np.mean([x for x in duration_rating[duration]["task_type"][task_ctg] if x >= 0]):.3f}'
duration_rating[duration]['task_type'][task_ctg] = task_res_dur
return duration_rating
def extract_option(model, input_item, dataset_name):
options = input_item['question'].split('\n')[1:]
for id, option in enumerate(options):
option_id = chr(ord('A') + id) + '.'
if option.find(option_id) >= 0:
input_item[chr(ord('A') + id)] = option[option.find(option_id) + len(option_id):].strip('. \n')
return extract_answer_from_item(model, input_item, dataset_name)['opt']
def extract_characters_regex(s):
s = s.strip()
answer_prefixes = [
'The best answer is',
'The correct answer is',
'The answer is',
'The answer',
        'The best option is',
        'The correct option is',
        'Best answer:',
        'Best option:',
'Answer:',
'Option:',
]
for answer_prefix in answer_prefixes:
s = s.replace(answer_prefix, '')
if len(s.split()) > 10 and not re.search('[ABCD]', s):
return ''
matches = re.search(r'[ABCD]', s)
if matches is None:
return ''
return matches[0]
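# Examples of the extraction above (pure function, safe to run):
#
#     extract_characters_regex('The best answer is B.')  # -> 'B'
#     extract_characters_regex('I am not sure what the video shows, it could be many things')  # -> ''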
# Copyright (c) OpenMMLab. All rights reserved.
# Partly adopted from https://github.com/GT-Vision-Lab/VQA
# Copyright (c) 2014, Aishwarya Agrawal
from ...smp import *
from typing import Optional
def _process_digit_article(inText):
outText = []
tempText = inText.lower().split()
articles = ['a', 'an', 'the']
manualMap = {
'none': '0',
'zero': '0',
'one': '1',
'two': '2',
'three': '3',
'four': '4',
'five': '5',
'six': '6',
'seven': '7',
'eight': '8',
'nine': '9',
'ten': '10',
}
contractions = {
'aint': "ain't",
'arent': "aren't",
'cant': "can't",
'couldve': "could've",
'couldnt': "couldn't",
"couldn'tve": "couldn't've",
"couldnt've": "couldn't've",
'didnt': "didn't",
'doesnt': "doesn't",
'dont': "don't",
'hadnt': "hadn't",
"hadnt've": "hadn't've",
"hadn'tve": "hadn't've",
'hasnt': "hasn't",
'havent': "haven't",
'hed': "he'd",
"hed've": "he'd've",
"he'dve": "he'd've",
'hes': "he's",
'howd': "how'd",
'howll': "how'll",
'hows': "how's",
"Id've": "I'd've",
"I'dve": "I'd've",
'Im': "I'm",
'Ive': "I've",
'isnt': "isn't",
'itd': "it'd",
"itd've": "it'd've",
"it'dve": "it'd've",
'itll': "it'll",
"let's": "let's",
'maam': "ma'am",
'mightnt': "mightn't",
"mightnt've": "mightn't've",
"mightn'tve": "mightn't've",
'mightve': "might've",
'mustnt': "mustn't",
'mustve': "must've",
'neednt': "needn't",
'notve': "not've",
'oclock': "o'clock",
'oughtnt': "oughtn't",
"ow's'at": "'ow's'at",
"'ows'at": "'ow's'at",
"'ow'sat": "'ow's'at",
'shant': "shan't",
"shed've": "she'd've",
"she'dve": "she'd've",
"she's": "she's",
'shouldve': "should've",
'shouldnt': "shouldn't",
"shouldnt've": "shouldn't've",
"shouldn'tve": "shouldn't've",
"somebody'd": 'somebodyd',
"somebodyd've": "somebody'd've",
"somebody'dve": "somebody'd've",
'somebodyll': "somebody'll",
'somebodys': "somebody's",
'someoned': "someone'd",
"someoned've": "someone'd've",
"someone'dve": "someone'd've",
'someonell': "someone'll",
'someones': "someone's",
'somethingd': "something'd",
"somethingd've": "something'd've",
"something'dve": "something'd've",
'somethingll': "something'll",
'thats': "that's",
'thered': "there'd",
"thered've": "there'd've",
"there'dve": "there'd've",
'therere': "there're",
'theres': "there's",
'theyd': "they'd",
"theyd've": "they'd've",
"they'dve": "they'd've",
'theyll': "they'll",
'theyre': "they're",
'theyve': "they've",
'twas': "'twas",
'wasnt': "wasn't",
"wed've": "we'd've",
"we'dve": "we'd've",
'weve': "we've",
'werent': "weren't",
'whatll': "what'll",
'whatre': "what're",
'whats': "what's",
'whatve': "what've",
'whens': "when's",
'whered': "where'd",
'wheres': "where's",
'whereve': "where've",
'whod': "who'd",
"whod've": "who'd've",
"who'dve": "who'd've",
'wholl': "who'll",
'whos': "who's",
'whove': "who've",
'whyll': "why'll",
'whyre': "why're",
'whys': "why's",
'wont': "won't",
'wouldve': "would've",
'wouldnt': "wouldn't",
"wouldnt've": "wouldn't've",
"wouldn'tve": "wouldn't've",
'yall': "y'all",
"yall'll": "y'all'll",
"y'allll": "y'all'll",
"yall'd've": "y'all'd've",
"y'alld've": "y'all'd've",
"y'all'dve": "y'all'd've",
'youd': "you'd",
"youd've": "you'd've",
"you'dve": "you'd've",
'youll': "you'll",
'youre': "you're",
'youve': "you've",
}
for word in tempText:
word = manualMap.setdefault(word, word)
if word not in articles:
outText.append(word)
for wordId, word in enumerate(outText):
if word in contractions:
outText[wordId] = contractions[word]
outText = ' '.join(outText)
return outText
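# Example of the normalization above (pure function, safe to run):
#
#     _process_digit_article('The two dogs')  # -> '2 dogs'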
def hit_calculate(result, dataset_name, anls_threshold=0.5):
if listinstr(['TextVQA'], dataset_name):
return [np.mean(x['match']) for x in result]
elif listinstr(['DocVQA', 'InfoVQA'], dataset_name):
return [0.0 if 1 - np.min(x['match']) < anls_threshold else 1 - np.min(x['match']) for x in result]
elif listinstr(['ChartQA', 'OCRVQA'], dataset_name):
return [np.max(x['match']) for x in result]
else: # default using vqa_score to calculate score
return [np.mean(x['match']) for x in result]
# https://github.com/google-research/pix2struct/blob/main/pix2struct/metrics.py#L81
def relaxed_correctness(target: str,
prediction: str,
max_relative_change: float = 0.05) -> bool:
"""Calculates relaxed correctness.
The correctness tolerates certain error ratio defined by max_relative_change.
See https://arxiv.org/pdf/2203.10244.pdf, end of section 5.1:
“Following Methani et al. (2020), we use a relaxed accuracy measure for the
numeric answers to allow a minor inaccuracy that may result from the automatic
data extraction process. We consider an answer to be correct if it is within
5% of the gold answer. For non-numeric answers, we still need an exact match
to consider an answer to be correct.”
Args:
target: Target string.
prediction: Predicted string.
max_relative_change: Maximum relative change.
Returns:
Whether the prediction was correct given the specified tolerance.
"""
def _to_float(text: str) -> Optional[float]:
try:
if text.endswith('%'):
# Convert percentages to floats.
return float(text.rstrip('%')) / 100.0
else:
return float(text)
except ValueError:
return None
prediction = str(prediction)
target = str(target)
prediction_float = _to_float(prediction)
target_float = _to_float(target)
if prediction_float is not None and target_float:
relative_change = abs(prediction_float - target_float) / abs(target_float)
return relative_change <= max_relative_change
else:
return prediction.lower() == target.lower()
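# Examples of the relaxed tolerance above (pure function, safe to run):
#
#     relaxed_correctness('100', '104')  # -> True  (4% relative change)
#     relaxed_correctness('100', '106')  # -> False (6% exceeds the 5% tolerance)
#     relaxed_correctness('cat', 'Cat')  # -> True  (non-numeric: case-insensitive exact match)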
def levenshtein_distance(s1, s2):
if len(s1) > len(s2):
s1, s2 = s2, s1
distances = range(len(s1) + 1)
for i2, c2 in enumerate(s2):
distances_ = [i2 + 1]
for i1, c1 in enumerate(s1):
if c1 == c2:
distances_.append(distances[i1])
else:
distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1])))
distances = distances_
return distances[-1]
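# Example (the classic textbook pair):
#
#     levenshtein_distance('kitten', 'sitting')  # -> 3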
def anls_compute(groundtruth, prediction):
gt_answer = ' '.join(groundtruth.strip().lower().split())
det_answer = ' '.join(prediction.strip().lower().split())
dist = levenshtein_distance(gt_answer, det_answer)
length = max(len(groundtruth.upper()), len(prediction.upper()))
values = 0.0 if length == 0 else float(dist) / float(length)
return values
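# Example: one character edit over an 11-character ground truth. Note this
# returns the *normalized distance* (lower is better); hit_calculate above
# converts it to an ANLS score via 1 - min(match).
#
#     anls_compute('hello world', 'helo world')  # -> 1 / 11 ~= 0.0909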
def process_answer(answer):
answer = answer.replace('\n', ' ')
answer = answer.replace('\t', ' ')
answer = answer.strip()
answer = process_punctuation(answer)
answer = _process_digit_article(answer)
return answer
def process_line(line, method='vqa_score'):
ret = {}
if istype(line['answer'], list):
answers = eval(line['answer'])
else:
answers = [line['answer']]
if method == 'vqa_score':
ret['gt'] = [process_answer(x) for x in answers]
ret['pred'] = process_answer(line['prediction'])
ret['match'] = []
for current_idx, gtAnsDatum in enumerate(ret['gt']):
otherGTAns = [
item for ret_gt_idx, item in enumerate(ret['gt'])
if ret_gt_idx != current_idx
]
matchingAns = [
item for item in otherGTAns if item == ret['pred']
]
acc = min(1, float(len(matchingAns)) / 3)
ret['match'].append(acc)
elif method == 'anls':
ret['gt'] = answers
ret['pred'] = line['prediction']
ret['match'] = [anls_compute(x, ret['pred']) for x in ret['gt']]
elif method == 'relaxed_accuracy':
ret['gt'] = answers
ret['pred'] = line['prediction'].strip()
ret['match'] = [relaxed_correctness(ret['pred'], x) for x in ret['gt']]
elif method == 'accuracy':
ret['gt'] = answers
ret['pred'] = line['prediction'].strip()
ret['match'] = [(1.0 if (x.strip().lower() == ret['pred'].strip().lower()) else 0.0) for x in ret['gt']]
else: # default using vqa_score to calculate score
ret['gt'] = [process_answer(x) for x in answers]
ret['pred'] = process_answer(line['prediction'])
ret['match'] = [x == ret['pred'] for x in ret['gt']]
return ret
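# A minimal sketch of scoring one record (hypothetical line, assuming the
# repo's `istype` recognizes the stringified list in 'answer'):
#
#     line = {'answer': "['two', '2']", 'prediction': '2'}
#     process_line(line, method='vqa_score')['match']
#     # -> [1/3, 1/3]: for each reference, one other reference equals the
#     # prediction, giving min(1, 1/3) per the VQA accuracy formula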
from ...smp import *
def AMBER_rating(data_file):
data = load(data_file)
stats = defaultdict(dict)
lt = len(data)
category_mapping = {
'discriminative-attribute-state': 'Attribute',
'discriminative-attribute-number': 'Attribute',
'discriminative-attribute-action': 'Attribute',
'discriminative-hallucination': 'Existence',
'discriminative-relation': 'Relation',
'relation': 'Relation'
}
for i in range(lt):
item = data.iloc[i]
category = item['category']
image_path = item['image_path']
score = item['score']
new_category = category_mapping.get(category, category)
if image_path not in stats[new_category]:
stats[new_category][image_path] = []
stats[new_category][image_path].append(score)
def acc(key):
res = stats[key]
values = []
for val in res.values():
values.extend(val)
return np.mean(values) * 100
scores = {}
for k in stats:
scores[k] = acc(k)
scores['Avg ACC'] = np.mean(list(scores.values()))
ret = d2df(scores)
return ret
def MME_rating(data_file):
data = load(data_file)
stats = defaultdict(dict)
lt = len(data)
for i in range(lt):
item = data.iloc[i]
category = item['category']
image_path = item['image_path']
score = item['score']
if image_path not in stats[category]:
stats[category][image_path] = []
stats[category][image_path].append(score)
def acc(key, mode='normal'):
res = stats[key]
values = []
for val in res.values():
if mode == 'normal':
values.extend(val)
elif mode == 'plus':
values.append(val[0] * val[1])
return np.mean(values) * 100
scores = {}
for k in stats:
scores[k] = acc(k) + acc(k, 'plus')
super_cates = dict(
perception=[
'OCR', 'artwork', 'celebrity', 'color', 'count', 'existence',
'landmark', 'position', 'posters', 'scene'
],
reasoning=['code_reasoning', 'commonsense_reasoning', 'numerical_calculation', 'text_translation']
)
ret = {}
for sc, cate_list in super_cates.items():
base = 0
for c in cate_list:
base += scores[c]
ret[sc] = base
ret.update(scores)
ret = d2df(ret)
return ret
def Hallusion_rating(data_file):
def calc_fAcc(data):
res = defaultdict(list)
lt = len(data)
for i in range(lt):
line = data.iloc[i]
res[f"{line['l2-category']}_{line['set_id']}_{line['figure_id']}"].append(line['score'])
return np.mean([np.all(x) for x in res.values()]) * 100
def calc_qAcc(data):
res = defaultdict(list)
lt = len(data)
for i in range(lt):
line = data.iloc[i]
res[f"{line['l2-category']}_{line['set_id']}_{line['question_id']}"].append(line['score'])
return np.mean([np.all(x) for x in res.values()]) * 100
def calc_aAcc(data):
return np.mean(data['score']) * 100
data = load(data_file)
data['set_id'] = [x.split('_')[3] for x in data['index']]
data['figure_id'] = [x.split('_')[4] for x in data['index']]
data['question_id'] = [x.split('_')[5] for x in data['index']]
res = dict(split=[], aAcc=[], fAcc=[], qAcc=[])
res['split'].append('Overall')
res['aAcc'].append(calc_aAcc(data))
res['fAcc'].append(calc_fAcc(data))
res['qAcc'].append(calc_qAcc(data))
if 'category' in data:
cates = list(set(data['category']))
for c in cates:
sub = data[data['category'] == c]
res['split'].append(c)
res['aAcc'].append(calc_aAcc(sub))
res['fAcc'].append(calc_fAcc(sub))
res['qAcc'].append(calc_qAcc(sub))
if 'l2-category' in data:
cates = list(set(data['l2-category']))
for c in cates:
sub = data[data['l2-category'] == c]
res['split'].append(c)
res['aAcc'].append(calc_aAcc(sub))
res['fAcc'].append(calc_fAcc(sub))
res['qAcc'].append(calc_qAcc(sub))
ret = pd.DataFrame(res)
return ret
def POPE_rating(data_file):
def cal_f1_score(y_true, y_pred):
tp = sum((y_true == 1) & (y_pred == 1))
fp = sum((y_true == 0) & (y_pred == 1))
fn = sum((y_true == 1) & (y_pred == 0))
precision = tp / (tp + fp) if (tp + fp) != 0 else 0
recall = tp / (tp + fn) if (tp + fn) != 0 else 0
f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) != 0 else 0
return f1_score, precision, recall
data = load(data_file)
data = data.assign(category=data['category'].str.split(',')).explode('category')
data['index'] = range(len(data))
res = dict(split=[], Overall=[], acc=[], precision=[], recall=[])
y_true = np.array([1 if i == 'Yes' else 0 for i in data['answer']])
y_pred = np.array([1 if i == 'Yes' else 0 for i in data['extracted']])
f1_score, precision, recall = cal_f1_score(y_true, y_pred)
res['split'].append('Overall')
res['Overall'].append(f1_score * 100)
res['acc'].append(np.mean(data['score']) * 100)
res['precision'].append(precision * 100)
res['recall'].append(recall * 100)
if 'category' in data:
cates = list(set(data['category']))
cates = [c for c in cates if not pd.isna(c)]
for c in cates:
sub = data[data['category'] == c]
y_true = np.array([1 if i == 'Yes' else 0 for i in sub['answer']])
y_pred = np.array([1 if i == 'Yes' else 0 for i in sub['extracted']])
f1_score, precision, recall = cal_f1_score(y_true, y_pred)
res['split'].append(c)
res['Overall'].append(f1_score * 100)
res['acc'].append(np.mean(sub['score']) * 100)
res['precision'].append(precision * 100)
res['recall'].append(recall * 100)
ret = pd.DataFrame(res)
return ret
def default_rating(data_file):
data = load(data_file)
res = {}
res['Overall'] = np.mean(data['score']) * 100
if 'category' in data:
cates = list(set(data['category']))
cates = [c for c in cates if not pd.isna(c)]
cates.sort()
for c in cates:
sub = data[data['category'] == c]
res[c] = np.mean(sub['score']) * 100
if 'l2-category' in data:
cates = list(set(data['l2-category']))
cates = [c for c in cates if not pd.isna(c)]
cates.sort()
for c in cates:
sub = data[data['l2-category'] == c]
res[c] = np.mean(sub['score']) * 100
ret = d2df(res)
return ret
def YOrN_match_prompt(line):
tmpl = (
'You are an AI assistant who will help me to match an answer with two options of a question. '
'The options are only Yes / No. '
'You are provided with a question and an answer, '
'and you need to find which option (Yes / No) is most similar to the answer. '
'If the meaning of all options are significantly different from the answer, output Unknown. '
'Your should output a single word among the following 3 choices: Yes, No, Unknown.\n'
'Example 1: \n'
"Question: Is the word in this image 'Hello'?\nAnswer: The word in this image is 'Hello'.\nYour output: Yes\n"
'Example 2: \n'
"Question: Is the word in this image 'Hello'?\n"
"Answer: The word in this image is not 'Hello'.\nYour output: No\n"
'Example 3: \n'
'Question: {}?\nAnswer: {}\nYour output: '
)
return tmpl.format(line['question'], line['prediction'])
def YOrN_Extraction(output):
s = output.lower()
words = process_punctuation(s).split()
if 'yes' in words and 'no' not in words:
return 'Yes'
if 'yes' not in words and 'no' in words:
return 'No'
return 'Unknown'
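# Examples of the Yes/No extraction above (assuming the repo's
# process_punctuation strips punctuation, as its name suggests):
#
#     YOrN_Extraction("Yes, the word is 'Hello'.")  # -> 'Yes'
#     YOrN_Extraction('Yes and no.')                # -> 'Unknown' (both present)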
def YOrN_auxeval(model, line):
prompt = YOrN_match_prompt(line)
retry = 5
for i in range(retry):
output = model.generate(prompt, temperature=0.5 * i)
ans = YOrN_Extraction(output)
if ans != 'Unknown':
return ans
return 'Unknown'
import uuid
from functools import partial
from .image_base import ImageBaseDataset
from ..smp import *
rouge = None
nlp_en = None
nlp_zh = None
nlp = None
def initialize():
import evaluate
import spacy
global rouge, nlp_en, nlp_zh, nlp
    try:
        rouge = evaluate.load('rouge', experiment_id=str(uuid.uuid4()))
    except Exception:
        warnings.warn('Please first `pip install rouge_score`.')
    try:
        nlp_en = spacy.load('en_core_web_sm')
    except OSError:
        warnings.warn('Will automatically download en_core_web_sm via spacy.')
        spacy.cli.download('en_core_web_sm')
        nlp_en = spacy.load('en_core_web_sm')
    try:
        nlp_zh = spacy.load('zh_core_web_sm')
    except OSError:
        warnings.warn('Will automatically download zh_core_web_sm via spacy.')
        spacy.cli.download('zh_core_web_sm')
        nlp_zh = spacy.load('zh_core_web_sm')
nlp = {'en': nlp_en, 'zh': nlp_zh}
def rough_filter(answer_text):
if "I can't" in answer_text:
return False
elif 'I cannot' in answer_text:
return False
elif 'sorry' in answer_text.lower():
return False
if '无法' in answer_text:
return False
elif '抱歉' in answer_text:
return False
else:
return True
def zero_template(crossed_text):
return {
'crossed_text': crossed_text,
'max_sim_val': 0,
'max_sim_string': '',
'precision': 0,
'recall': 0,
'f1': 0,
'jaccard': 0,
'rouge1': 0,
'exact_match': 0,
}
def tokenize(text, language):
"""
Tokenize the text and return the tokens.
Parameters:
text (str): The text to tokenize.
language (str): The language of the text.
Returns:
list: The list of tokens.
"""
assert language in ['en', 'zh']
nlp_language = nlp[language]
processed_text = nlp_language(text)
return [token.text for token in processed_text]
def find_best_match(needle, hay, language, rouge):
"""
Finds the best matching n-gram in the haystack for the given needle.
Parameters:
needle (str): The string to find.
hay (str): The text to search within.
Returns:
tuple: The highest similarity value and the best matching string.
"""
assert language in ['en', 'zh']
from nltk.util import ngrams
from difflib import SequenceMatcher as SM
tokens_hay = tokenize(hay, language)
tokens_needle = tokenize(needle, language)
splitter = '' if language == 'zh' else ' '
ngrams_ = ngrams(tokens_hay, len(tokens_needle))
max_sim_val = 0
max_sim_string = ''
max_sim_ngram = []
tokens_needle_set = set(tokens_needle)
ngrams_hasjoint = [
ngram
for ngram in ngrams_
if not set(ngram).isdisjoint(tokens_needle_set)
]
for ngram in ngrams_hasjoint:
hay_ngram = splitter.join(ngram)
similarity = SM(None, hay_ngram, needle).ratio()
if similarity > max_sim_val:
max_sim_val = similarity
max_sim_string = hay_ngram
max_sim_ngram = ngram
# Evaluate
if len(max_sim_ngram) == 0:
return {
'crossed_text': needle,
'max_sim_val': 0,
'max_sim_string': '',
'precision': 0,
'recall': 0,
'f1': 0,
'jaccard': 0,
'rouge1': 0,
'exact_match': 0,
}
pred_set = set(max_sim_ngram)
ref_set = set(tokens_needle)
correct_tokens = pred_set.intersection(ref_set)
len_correct_tokens = len(correct_tokens)
precision = len_correct_tokens / len(pred_set)
recall = len_correct_tokens / len(ref_set)
if (precision + recall) == 0:
f1 = 0
else:
f1 = 2 * precision * recall / (precision + recall)
union = pred_set.union(ref_set)
jaccard = len_correct_tokens / len(union) if len(union) > 0 else 0
rouge_1 = rouge.compute(
predictions=[max_sim_string],
references=[needle],
tokenizer=partial(tokenize, language=language),
rouge_types=['rouge1'],
)['rouge1']
exact_match = float(list(max_sim_ngram) == list(tokens_needle))
out = {
'crossed_text': needle,
'max_sim_string': max_sim_string,
'max_sim_val': max_sim_val,
'precision': precision,
'recall': recall,
'f1': f1,
'jaccard': jaccard,
'rouge1': rouge_1,
'exact_match': exact_match,
}
return out
def process_match_single_new(
image_id, prediction, answer, language, progress
):
"""
process the inference results for a single image and calculate the metrics
Parameters:
image_id (int): The image id (question id).
prediction (str): The prediction text.
answer (Union[str, List[str]]): The answer text, or a list of answer texts. The masked n-grams in the image.
language (str): The language of the text. Can be "en" or "zh".
rouge (rouge): The rouge metric object.
progress (multiprocessing.Queue): The progress queue.
Returns:
tuple: The image id (question_id, int) and the result per id (dict of dict of dict).
"""
result_per_id = {image_id: {}}
if isinstance(answer, str):
answer = eval(answer)
assert isinstance(answer, list)
result = prediction.split('Assistant: ')[-1]
for i, crossed_text in enumerate(answer):
if rough_filter(result):
find_best_match_result = find_best_match(
crossed_text, result, language, rouge
)
if i == 0:
result_per_id[image_id] = {str(i): find_best_match_result}
else:
result_per_id[image_id][str(i)] = find_best_match_result
else:
if i == 0:
result_per_id[image_id] = {str(i): zero_template(crossed_text)}
else:
result_per_id[image_id][str(i)] = zero_template(crossed_text)
progress.put(1)
return image_id, result_per_id
class VCRDataset(ImageBaseDataset):
TYPE = 'VQA'
URL_PREFIX = 'https://huggingface.co/datasets/vcr-org'
DATASET_URL = {
'VCR_EN_EASY_500': f'{URL_PREFIX}/VCR-wiki-en-easy-test-500/resolve/main/VCR-wiki-en-easy-test-500.tsv',
'VCR_EN_EASY_100': f'{URL_PREFIX}/VCR-wiki-en-easy-test-100/resolve/main/VCR-wiki-en-easy-test-100.tsv',
'VCR_EN_EASY_ALL': f'{URL_PREFIX}/VCR-wiki-en-easy-test/resolve/main/VCR-wiki-en-easy-test.tsv',
'VCR_EN_HARD_500': f'{URL_PREFIX}/VCR-wiki-en-hard-test-500/resolve/main/VCR-wiki-en-hard-test-500.tsv',
'VCR_EN_HARD_100': f'{URL_PREFIX}/VCR-wiki-en-hard-test-100/resolve/main/VCR-wiki-en-hard-test-100.tsv',
'VCR_EN_HARD_ALL': f'{URL_PREFIX}/VCR-wiki-en-hard-test/resolve/main/VCR-wiki-en-hard-test.tsv',
'VCR_ZH_EASY_500': f'{URL_PREFIX}/VCR-wiki-zh-easy-test-500/resolve/main/VCR-wiki-zh-easy-test-500.tsv',
'VCR_ZH_EASY_100': f'{URL_PREFIX}/VCR-wiki-zh-easy-test-100/resolve/main/VCR-wiki-zh-easy-test-100.tsv',
'VCR_ZH_EASY_ALL': f'{URL_PREFIX}/VCR-wiki-zh-easy-test/resolve/main/VCR-wiki-zh-easy-test.tsv',
'VCR_ZH_HARD_500': f'{URL_PREFIX}/VCR-wiki-zh-hard-test-500/resolve/main/VCR-wiki-zh-hard-test-500.tsv',
'VCR_ZH_HARD_100': f'{URL_PREFIX}/VCR-wiki-zh-hard-test-100/resolve/main/VCR-wiki-zh-hard-test-100.tsv',
'VCR_ZH_HARD_ALL': f'{URL_PREFIX}/VCR-wiki-zh-hard-test/resolve/main/VCR-wiki-zh-hard-test.tsv',
}
DATASET_MD5 = {
'VCR_EN_EASY_500': 'fd9258db52f8685dc710619a0ea0a261',
'VCR_EN_EASY_100': '9df5d7266683458621ecbe122beb72f0',
'VCR_EN_EASY_ALL': '8a9b96885f251d1c85f42f84073327f1',
'VCR_EN_HARD_500': '0a22a85080b6a1f52b1f95e302d43df4',
'VCR_EN_HARD_100': '1b20f5cbcbeae0b0bec77f7a36143958',
'VCR_EN_HARD_ALL': '2d8b8b1ee0eba0e0b618fd3aa7d9710e',
'VCR_ZH_EASY_500': 'beca5fd54176adf44cf94bd9b50cf048',
'VCR_ZH_EASY_100': '4a86a5678a79844d6d22ab0629c51cd5',
'VCR_ZH_EASY_ALL': '5050fe7f0027ad2068fd4c7f220edaea',
'VCR_ZH_HARD_500': '617e3360f75c54455625cb0a8da5c1e7',
'VCR_ZH_HARD_100': 'b0e38c85f5d5e63894a3b881c372a62b',
'VCR_ZH_HARD_ALL': '54bbfef448206518b03127ef8b61404c',
}
def __init__(self, dataset='VCR_EN_EASY_500', skip_noimg=True):
super().__init__(dataset, skip_noimg)
initialize()
self.language = 'en' if 'EN' in dataset else 'zh'
self.difficulty = 'easy' if 'EASY' in dataset else 'hard'
# def build_prompt(self, line):
# msgs = super().build_prompt(line)
# assert msgs[-1]['type'] == 'text'
# if self.language == 'zh':
# msgs[-1]['value'] += '图像中被覆盖的文本是什么?请在不输出解释的情况下还原被覆盖的文本。'
# else:
# msgs[-1]['value'] += ('What is the covered texts in the image? '
# 'Please restore the covered texts without outputting the explanations.')
# return msgs
def evaluate(self, eval_file, **judge_kwargs):
import multiprocessing
vcr_score_list = {'Exact_Match': [], 'Jaccard': []}
vcr_score = {'Exact_Match': 0, 'Jaccard': 0}
logger = get_logger('Evaluation')
data = load(eval_file)
lt = len(data)
lines = [data.iloc[i] for i in range(lt)]
pool = multiprocessing.Pool()
manager = multiprocessing.Manager()
progress_queue = manager.Queue()
results = []
overall_results = {str(image_id): {} for image_id in range(len(lines))}
for instance_id, instance in enumerate(lines):
results.append(
pool.apply_async(
process_match_single_new,
args=(
str(instance_id),
instance['prediction'],
instance['answer'],
self.language,
progress_queue,
),
)
)
pool.close()
# Display progress bar
for _ in tqdm(range(len(results))):
progress_queue.get()
pool.join()
# Merging results into overall_result
for result in results:
image_id, result_per_id = result.get()
overall_results[str(image_id)].update(result_per_id[image_id])
for blank_id_str in result_per_id[image_id].keys():
vcr_score_list['Exact_Match'].append(
result_per_id[image_id][blank_id_str]['exact_match']
)
vcr_score_list['Jaccard'].append(
result_per_id[image_id][blank_id_str]['jaccard']
)
vcr_score['Exact_Match'] = np.mean(vcr_score_list['Exact_Match'])
vcr_score['Jaccard'] = np.mean(vcr_score_list['Jaccard'])
results_out = {
k: v for i in range(len(results)) for k, v in results[i].get()[1].items()
}
results_with_metrics = {
'Exact_Match': vcr_score['Exact_Match'],
'Jaccard': vcr_score['Jaccard'],
'Predictions': results_out,
}
score_pth = eval_file.replace(
'.xlsx', f'{self.language}_{self.difficulty}_score.json'
)
dump(results_with_metrics, score_pth)
logger.info(
f'VCR successfully finished evaluating {eval_file}, results saved in {score_pth}'
)
logger.info('Score: ')
for key, value in vcr_score.items():
logger.info('{}:{}'.format(key, value))
from abc import abstractmethod
from ..smp import *
class VideoBaseDataset:
MODALITY = 'VIDEO'
def __init__(self,
dataset='MMBench-Video',
pack=False):
        try:
            import decord  # noqa: F401, availability check only
        except ImportError:
            warnings.warn('Please install decord via `pip install decord`.')
self.dataset_name = dataset
ret = self.prepare_dataset(dataset)
assert ret is not None
lmu_root = LMUDataRoot()
self.frame_root = osp.join(lmu_root, 'images', dataset)
os.makedirs(self.frame_root, exist_ok=True)
self.frame_tmpl = 'frame-{}-of-{}.jpg'
self.data_root = ret['root']
self.data_file = ret['data_file']
self.data = load(self.data_file)
assert 'question' in self.data and 'video' in self.data
videos = list(set(self.data['video']))
videos.sort()
self.videos = videos
self.pack = pack
def __len__(self):
return len(self.videos) if self.pack else len(self.data)
def __getitem__(self, idx):
if self.pack:
assert idx < len(self.videos)
sub_data = self.data[self.data['video'] == self.videos[idx]]
return sub_data
else:
assert idx < len(self.data)
return dict(self.data.iloc[idx])
def frame_paths(self, video, num_frames=8):
frame_root = osp.join(self.frame_root, video)
os.makedirs(frame_root, exist_ok=True)
return [osp.join(frame_root, self.frame_tmpl.format(i, num_frames)) for i in range(1, num_frames + 1)]
def save_video_frames(self, video, num_frames=8):
frame_paths = self.frame_paths(video, num_frames)
flag = np.all([osp.exists(p) for p in frame_paths])
if flag:
return frame_paths
        import decord  # local import; availability was checked in __init__
        vid_path = osp.join(self.data_root, video + '.mp4')
        vid = decord.VideoReader(vid_path)
step_size = len(vid) / (num_frames + 1)
indices = [int(i * step_size) for i in range(1, num_frames + 1)]
images = [vid[i].asnumpy() for i in indices]
images = [Image.fromarray(arr) for arr in images]
for im, pth in zip(images, frame_paths):
if not osp.exists(pth):
im.save(pth)
return frame_paths
# Return a list of dataset names supported by this class; subclasses may override
@classmethod
def supported_datasets(cls):
return ['MMBench-Video', 'Video-MME', 'MVBench', 'MVBench_MP4']
# Given the prediction file, return the evaluation results in the format of a dictionary or pandas dataframe
@abstractmethod
def evaluate(self, eval_file, **judge_kwargs):
pass
@abstractmethod
def build_prompt(self, idx, num_frames=8):
pass
@abstractmethod
def prepare_dataset(self, dataset):
# The prepare_dataset function should return a dictionary containing:
# `root` (the directory containing the video files)
# `data_file` (the path to the TSV dataset file)
pass
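# A minimal sketch of the expected return value (hypothetical paths):
# >>> self.prepare_dataset('MMBench-Video')
# {'root': '/home/user/LMUData/videos', 'data_file': '/home/user/LMUData/MMBench-Video.tsv'}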
from huggingface_hub import snapshot_download
from ..smp import *
from .video_base import VideoBaseDataset
from .utils import build_judge, DEBUG_MESSAGE
FAIL_MSG = 'Failed to obtain answer via API.'
def unwrap_hf_pkl(pth, suffix='.mp4'):
base_dir = os.path.join(pth, 'video_pkl/')
target_dir = os.path.join(pth, 'video/')
pickle_files = [os.path.join(base_dir, file) for file in os.listdir(base_dir)]
pickle_files.sort()
if not os.path.exists(target_dir):
os.makedirs(target_dir, exist_ok=True)
for pickle_file in pickle_files:
with open(pickle_file, 'rb') as file:
video_data = pickle.load(file)
# For each video file in the pickle file, write its contents to a new mp4 file
for video_name, video_content in video_data.items():
output_path = os.path.join(target_dir, f'{video_name}{suffix}')
with open(output_path, 'wb') as output_file:
output_file.write(video_content)
print('The video files have been restored from the pickle files.')
else:
print('The video files already exist.')
class VideoMME(VideoBaseDataset):
MD5 = '85bdd91f9b29a99354c23b97ab7c113c'
SYS = ''
FRAMES_TMPL_NOSUB = """
These are the frames of a video. \
Select the best answer to the following multiple-choice question based on the video. \
Respond with only the letter (A, B, C, or D) of the correct option.
"""
FRAMES_TMPL_SUB = """
These are the frames of a video. \
This video's subtitles are listed below:
{}
Select the best answer to the following multiple-choice question based on the video. \
Respond with only the letter (A, B, C, or D) of the correct option.
"""
TYPE = 'Video-MCQ'
def __init__(self, dataset='Video-MME', use_subtitle=False):
super().__init__(dataset=dataset)
self.use_subtitle = use_subtitle
self.dataset_name = dataset
@classmethod
def supported_datasets(cls):
return ['Video-MME']
def prepare_dataset(self, dataset_name='Video-MME', repo_id='lmms-lab/Video-MME'):
def check_integrity(pth):
data_file = osp.join(pth, f'{dataset_name}.tsv')
if not os.path.exists(data_file):
return False
if md5(data_file) != self.MD5:
return False
data = load(data_file)
for video_pth in data['video_path']:
if not osp.exists(osp.join(pth, video_pth)):
return False
return True
cache_path = get_cache_path(repo_id)
if cache_path is not None and check_integrity(cache_path):
dataset_path = cache_path
else:
def unzip_hf_zip(pth):
import zipfile
base_dir = pth
target_dir = os.path.join(pth, 'video/')
zip_files = [
os.path.join(base_dir, file) for file in os.listdir(base_dir)
if file.endswith('.zip') and file.startswith('video')
]
zip_files.sort()
if not os.path.exists(target_dir):
os.makedirs(target_dir, exist_ok=True)
for zip_file in zip_files:
with zipfile.ZipFile(zip_file, 'r') as zip_ref:
for member in zip_ref.namelist():
# Check if the member is a file (not a directory)
if not member.endswith('/'):
# Extract the file to the specified directory
source = zip_ref.open(member)
target = open(os.path.join(target_dir, os.path.basename(member)), 'wb')
with source, target:
target.write(source.read())
print('The video files have been restored from the zip files.')
else:
print('The video files already exist.')
subtitle_zip_file = os.path.join(base_dir, 'subtitle.zip')
subtitle_target_dir = os.path.join(base_dir, 'subtitle')
if not os.path.exists(subtitle_target_dir):
os.makedirs(subtitle_target_dir, exist_ok=True)
with zipfile.ZipFile(subtitle_zip_file, 'r') as zip_ref:
for member in zip_ref.namelist():
# Check if the member is a file (not a directory)
if not member.endswith('/'):
# Extract the file to the specified directory
source = zip_ref.open(member)
target = open(os.path.join(subtitle_target_dir, os.path.basename(member)), 'wb')
with source, target:
target.write(source.read())
print('The subtitle files have been restored from the zip file.')
else:
print('The subtitle files already exist.')
def generate_tsv(pth):
data_file = osp.join(pth, f'{dataset_name}.tsv')
if os.path.exists(data_file) and md5(data_file) == self.MD5:
return
data_file = pd.read_parquet(os.path.join(pth, 'videomme/test-00000-of-00001.parquet'))
data_file = data_file.assign(index=range(len(data_file)))
data_file['video'] = data_file['videoID']
data_file['video_path'] = data_file['videoID'].apply(lambda x: f'./video/{x}.mp4')
data_file['subtitle_path'] = data_file['videoID'].apply(lambda x: f'./subtitle/{x}.srt')
data_file['candidates'] = data_file['options'].apply(lambda x: x.tolist())
data_file = data_file[['index', 'video', 'video_path', 'duration', 'domain', 'candidates',
'sub_category', 'task_type', 'subtitle_path', 'question', 'answer']]
data_file.to_csv(osp.join(pth, f'{dataset_name}.tsv'), sep='\t', index=False)
dataset_path = snapshot_download(repo_id=repo_id, repo_type='dataset')
unzip_hf_zip(dataset_path)
generate_tsv(dataset_path)
data_file = osp.join(dataset_path, f'{dataset_name}.tsv')
return dict(data_file=data_file, root=dataset_path)
def save_video_frames(self, video, num_frames=8):
vid_path = osp.join(self.data_root, 'video', video + '.mp4')
vid = decord.VideoReader(vid_path)
step_size = len(vid) / (num_frames + 1)
indices = [int(i * step_size) for i in range(1, num_frames + 1)]
video_info = {
'fps': vid.get_avg_fps(),
'n_frames': len(vid),
}
frame_paths = self.frame_paths(video, num_frames)
flag = np.all([osp.exists(p) for p in frame_paths])
if not flag:
images = [vid[i].asnumpy() for i in indices]
images = [Image.fromarray(arr) for arr in images]
for im, pth in zip(images, frame_paths):
if not osp.exists(pth):
im.save(pth)
return frame_paths, indices, video_info
def save_video_into_images(self, line, num_frames=8):
frame_paths, indices, video_info = self.save_video_frames(line['video'], num_frames)
return frame_paths
def build_prompt(self, line, num_frames, video_llm):
if isinstance(line, int):
assert line < len(self)
line = self.data.iloc[line]
frames, indices, video_info = self.save_video_frames(line['video'], num_frames)
if self.use_subtitle and os.path.exists(osp.join(self.data_root, line['subtitle_path'])):
import pysubs2
subs = pysubs2.load(osp.join(self.data_root, line['subtitle_path']), encoding='utf-8')
subtitles = []
for selected_frame_id in indices:
sub_text = ''
cur_time = pysubs2.make_time(fps=video_info['fps'], frames=selected_frame_id)
for sub in subs:
if sub.start < cur_time and sub.end > cur_time:
sub_text = sub.text.replace('\\N', ' ')
break
if sub_text.strip():
subtitles.append(sub_text)
subtitles = '\n'.join(subtitles)
else:
subtitles = ''
message = [dict(type='text', value=self.SYS)]
if video_llm:
message.append(dict(type='video', value=osp.join(self.data_root, 'video', line['video'] + '.mp4')))
else:
for im in frames:
message.append(dict(type='image', value=im))
text_prompt = self.FRAMES_TMPL_NOSUB if not self.use_subtitle else self.FRAMES_TMPL_SUB.format(subtitles)
message.append(dict(type='text', value=text_prompt))
line['question'] += '\n' + '\n'.join(eval(line['candidates']))
prompt = 'Question: {}\nAnswer: '.format(line['question'])
message.append(dict(type='text', value=prompt))
return message
# Returns a dictionary of rating results
@classmethod
def evaluate(cls, eval_file, **judge_kwargs):
from .utils.videomme import get_dimension_rating, extract_characters_regex, extract_option
assert eval_file.endswith('.xlsx'), 'data file should be an xlsx file'
tmp_file = eval_file.replace('.xlsx', '_tmp.pkl')
tgt_file = eval_file.replace('.xlsx', '_rating.json')
score_file = eval_file.replace('.xlsx', '_score.xlsx')
if not osp.exists(score_file):
model = judge_kwargs.get('model', 'exact_matching')
assert model in ['chatgpt-0125', 'exact_matching', 'gpt-4-0125']
if model == 'exact_matching':
model = None
elif gpt_key_set():
model = build_judge(**judge_kwargs)
if not model.working():
warnings.warn('OPENAI API is not working properly, will use exact matching for evaluation')
warnings.warn(DEBUG_MESSAGE)
model = None
else:
warnings.warn('OPENAI_API_KEY is not set properly, will use exact matching for evaluation')
model = None
res = {} if not osp.exists(tmp_file) else load(tmp_file)
res = {k: v for k, v in res.items() if FAIL_MSG not in v}
data = load(eval_file)
data_un = data[~pd.isna(data['prediction'])]
for idx in data['index']:
ans = data.loc[data['index'] == idx, 'answer'].values[0]
pred = str(data.loc[data['index'] == idx, 'prediction'].values[0])
if extract_characters_regex(pred) == '':
extract_pred = extract_option(
model,
data.loc[data['index'] == idx].to_dict(orient='records')[0],
'Video-MME'
)
data.loc[idx, 'score'] = int(extract_pred == ans)
else:
data.loc[idx, 'score'] = int(extract_characters_regex(pred) == ans)
rejected = [x for x in data['score'] if x == -1]
print(
f'Among {len(data)} questions, failed to obtain prediction for {len(data) - len(data_un)} questions, '
f'failed to obtain the score for another {len(rejected)} questions. '
f'Those questions will be counted as -1 score in ALL rating, and will not be counted in VALID rating.'
)
dump(data, score_file)
rating = get_dimension_rating(score_file)
dump(rating, tgt_file)
return rating
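# Usage sketch (hypothetical prediction file): scores and ratings are written
# next to the input xlsx as `<name>_score.xlsx` and `<name>_rating.json`.
# >>> VideoMME.evaluate('GPT4o_Video-MME_8frame_nopack.xlsx', model='exact_matching')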
import torch
import torch.distributed as dist
from vlmeval.config import supported_VLM
from vlmeval.utils import track_progress_rich
from vlmeval.smp import *
FAIL_MSG = 'Failed to obtain answer via API.'
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--data', type=str, nargs='+', required=True)
parser.add_argument('--model', type=str, nargs='+', required=True)
parser.add_argument('--nproc', type=int, default=4, required=True)
parser.add_argument('--verbose', action='store_true')
args = parser.parse_args()
return args
# Only API models are accepted
def infer_data_api(work_dir, model_name, dataset, index_set=None, api_nproc=4, ignore_failed=False):
rank, world_size = get_rank_and_world_size()
assert rank == 0 and world_size == 1
dataset_name = dataset.dataset_name
data = dataset.data
if index_set is not None:
data = data[data['index'].isin(index_set)]
model = supported_VLM[model_name]() if isinstance(model_name, str) else model_name
assert getattr(model, 'is_api', False)
if hasattr(model, 'set_dump_image'):
model.set_dump_image(dataset.dump_image)
lt, indices = len(data), list(data['index'])
structs = []
for i in range(lt):
item = data.iloc[i]
if hasattr(model, 'use_custom_prompt') and model.use_custom_prompt(dataset_name):
assert hasattr(model, 'build_prompt')
struct = model.build_prompt(item, dataset=dataset_name)
else:
struct = dataset.build_prompt(item)
structs.append(struct)
# structs = [dataset.build_prompt(data.iloc[i]) for i in range(lt)]
out_file = f'{work_dir}/{model_name}_{dataset_name}_supp.pkl'
res = {}
if osp.exists(out_file):
res = load(out_file)
if ignore_failed:
res = {k: v for k, v in res.items() if FAIL_MSG not in v}
structs = [s for i, s in zip(indices, structs) if i not in res]
indices = [i for i in indices if i not in res]
gen_func = model.generate
structs = [dict(message=struct, dataset=dataset_name) for struct in structs]
if len(structs):
track_progress_rich(gen_func, structs, nproc=api_nproc, chunksize=api_nproc, save=out_file, keys=indices)
res = load(out_file)
if index_set is not None:
res = {k: v for k, v in res.items() if k in index_set}
os.remove(out_file)
return res
def infer_data(model_name, work_dir, dataset, out_file, verbose=False, api_nproc=4):
dataset_name = dataset.dataset_name
prev_file = f'{work_dir}/{model_name}_{dataset_name}_PREV.pkl'
res = load(prev_file) if osp.exists(prev_file) else {}
if osp.exists(out_file):
res.update(load(out_file))
rank, world_size = get_rank_and_world_size()
sheet_indices = list(range(rank, len(dataset), world_size))
lt = len(sheet_indices)
data = dataset.data.iloc[sheet_indices]
data_indices = [i for i in data['index']]
# If finished, will exit without building the model
all_finished = True
for i in range(lt):
idx = data.iloc[i]['index']
if idx not in res:
all_finished = False
if all_finished:
res = {k: res[k] for k in data_indices}
dump(res, out_file)
return
# Data that still needs to be inferred
data = data[~data['index'].isin(res)]
lt = len(data)
model = supported_VLM[model_name]() if isinstance(model_name, str) else model_name
is_api = getattr(model, 'is_api', False)
if is_api:
lt, indices = len(data), list(data['index'])
supp = infer_data_api(
work_dir=work_dir,
model_name=model_name,
dataset=dataset,
index_set=set(indices),
api_nproc=api_nproc)
for idx in indices:
assert idx in supp
res.update(supp)
res = {k: res[k] for k in data_indices}
dump(res, out_file)
return model_name
else:
model.set_dump_image(dataset.dump_image)
for i in tqdm(range(lt)):
idx = data.iloc[i]['index']
if idx in res:
continue
if hasattr(model, 'use_custom_prompt') and model.use_custom_prompt(dataset_name):
struct = model.build_prompt(data.iloc[i], dataset=dataset_name)
else:
struct = dataset.build_prompt(data.iloc[i])
response = model.generate(message=struct, dataset=dataset_name)
torch.cuda.empty_cache()
if verbose:
print(response, flush=True)
res[idx] = response
if (i + 1) % 20 == 0:
dump(res, out_file)
res = {k: res[k] for k in data_indices}
dump(res, out_file)
return model
# A wrapper for infer_data that handles the pre- & post-processing
def infer_data_job(model, work_dir, model_name, dataset, verbose=False, api_nproc=4, ignore_failed=False):
rank, world_size = get_rank_and_world_size()
dataset_name = dataset.dataset_name
result_file = osp.join(work_dir, f'{model_name}_{dataset_name}.xlsx')
prev_file = f'{work_dir}/{model_name}_{dataset_name}_PREV.pkl'
if osp.exists(result_file):
if rank == 0:
data = load(result_file)
results = {k: v for k, v in zip(data['index'], data['prediction'])}
if not ignore_failed:
results = {k: v for k, v in results.items() if FAIL_MSG not in str(v)}
dump(results, prev_file)
if world_size > 1:
dist.barrier()
tmpl = osp.join(work_dir, '{}' + f'{world_size}_{dataset_name}.pkl')
out_file = tmpl.format(rank)
model = infer_data(
model, work_dir=work_dir, dataset=dataset, out_file=out_file, verbose=verbose, api_nproc=api_nproc)
if world_size > 1:
dist.barrier()
if rank == 0:
data_all = {}
for i in range(world_size):
data_all.update(load(tmpl.format(i)))
data = dataset.data
for x in data['index']:
assert x in data_all
data['prediction'] = [str(data_all[x]) for x in data['index']]
if 'image' in data:
data.pop('image')
dump(data, result_file)
for i in range(world_size):
os.remove(tmpl.format(i))
if world_size > 1:
dist.barrier()
return model
import torch
import torch.distributed as dist
from vlmeval.config import supported_VLM
from vlmeval.utils import track_progress_rich
from vlmeval.smp import *
FAIL_MSG = 'Failed to obtain answer via API.'
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--data', type=str, nargs='+', required=True)
parser.add_argument('--model', type=str, nargs='+', required=True)
parser.add_argument('--nproc', type=int, default=4, required=True)
parser.add_argument('--verbose', action='store_true')
args = parser.parse_args()
return args
def chat_mt(model, messages, dataset_name):
assert len(messages) % 2 == 0
nturn = len(messages) // 2
utter_stack = []
predictions = []
for i in range(nturn):
utter = messages[2 * i]
utter_stack.append(utter)
try:
resp = model.chat(utter_stack, dataset=dataset_name)
utter_stack.append(dict(role='assistant', content=resp))
except Exception:
resp = FAIL_MSG
utter_stack.append(dict(role='assistant', content=resp))
predictions.append(resp)
return predictions
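# Usage sketch (hypothetical turns): `messages` must alternate user / assistant
# entries; only the user turns (even indices) are fed to the model, and one
# prediction is returned per user turn.
# >>> msgs = [dict(role='user', content='Q1'), dict(role='assistant', content='A1'),
# ...         dict(role='user', content='Q2'), dict(role='assistant', content='A2')]
# >>> chat_mt(model, msgs, 'MMDU')  # -> ['pred for Q1', 'pred for Q2']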
# Only API models are accepted
def infer_data_api(work_dir, model_name, dataset, index_set=None, api_nproc=4, ignore_failed=False):
rank, world_size = get_rank_and_world_size()
assert rank == 0 and world_size == 1
dataset_name = dataset.dataset_name
data = dataset.data
if index_set is not None:
data = data[data['index'].isin(index_set)]
model = supported_VLM[model_name]() if isinstance(model_name, str) else model_name
assert getattr(model, 'is_api', False)
assert hasattr(model, 'chat_inner')
lt, indices = len(data), list(data['index'])
structs = [dataset.build_prompt(data.iloc[i]) for i in range(lt)]
out_file = f'{work_dir}/{model_name}_{dataset_name}_supp.pkl'
res = {}
if osp.exists(out_file):
res = load(out_file)
if ignore_failed:
res = {k: v for k, v in res.items() if FAIL_MSG not in v}
structs = [s for i, s in zip(indices, structs) if i not in res]
indices = [i for i in indices if i not in res]
structs = [dict(model=model, messages=struct, dataset_name=dataset_name) for struct in structs]
if len(structs):
track_progress_rich(chat_mt, structs, nproc=api_nproc, chunksize=api_nproc, save=out_file, keys=indices)
res = load(out_file)
if index_set is not None:
res = {k: v for k, v in res.items() if k in index_set}
os.remove(out_file)
return res
def infer_data(model_name, work_dir, dataset, out_file, verbose=False, api_nproc=4):
dataset_name = dataset.dataset_name
res = {}
if osp.exists(out_file):
res.update(load(out_file))
rank, world_size = get_rank_and_world_size()
sheet_indices = list(range(rank, len(dataset), world_size))
lt = len(sheet_indices)
data = dataset.data.iloc[sheet_indices]
data_indices = [i for i in data['index']]
# If finished, will exit without building the model
all_finished = True
for i in range(lt):
idx = data.iloc[i]['index']
if idx not in res:
all_finished = False
if all_finished:
res = {k: res[k] for k in data_indices}
dump(res, out_file)
return
# Data that still needs to be inferred
data = data[~data['index'].isin(res)]
lt = len(data)
model = supported_VLM[model_name]() if isinstance(model_name, str) else model_name
assert hasattr(model, 'chat_inner')
is_api = getattr(model, 'is_api', False)
if is_api:
lt, indices = len(data), list(data['index'])
supp = infer_data_api(
work_dir=work_dir,
model_name=model_name,
dataset=dataset,
index_set=set(indices),
api_nproc=api_nproc)
for idx in indices:
assert idx in supp
res.update(supp)
res = {k: res[k] for k in data_indices}
dump(res, out_file)
return model_name
else:
model.set_dump_image(dataset.dump_image)
for i in tqdm(range(lt)):
idx = data.iloc[i]['index']
if idx in res:
continue
if hasattr(model, 'use_custom_prompt') and model.use_custom_prompt(dataset_name):
struct = model.build_prompt(data.iloc[i], dataset=dataset_name)
else:
struct = dataset.build_prompt(data.iloc[i])
response = chat_mt(model, struct, dataset_name)
torch.cuda.empty_cache()
if verbose:
print(response, flush=True)
res[idx] = response
if (i + 1) % 20 == 0:
dump(res, out_file)
res = {k: res[k] for k in data_indices}
dump(res, out_file)
return model
# A wrapper for infer_data that handles the pre- & post-processing
def infer_data_job_mt(model, work_dir, model_name, dataset, verbose=False, api_nproc=4, ignore_failed=False):
rank, world_size = get_rank_and_world_size()
dataset_name = dataset.dataset_name
result_file = osp.join(work_dir, f'{model_name}_{dataset_name}.tsv')
tmpl = osp.join(work_dir, '{}' + f'{world_size}_{dataset_name}.pkl')
out_file = tmpl.format(rank)
model = infer_data(
model, work_dir=work_dir, dataset=dataset, out_file=out_file, verbose=verbose, api_nproc=api_nproc)
if world_size > 1:
dist.barrier()
if rank == 0:
data_all = {}
for i in range(world_size):
data_all.update(load(tmpl.format(i)))
data = dataset.data
for x in data['index']:
assert x in data_all
data['prediction'] = [data_all[x] for x in data['index']]
if 'image' in data:
data.pop('image')
dump(data, result_file)
for i in range(world_size):
os.remove(tmpl.format(i))
return model
import torch
import torch.distributed as dist
from vlmeval.config import supported_VLM
from vlmeval.utils import track_progress_rich
from vlmeval.smp import *
FAIL_MSG = 'Failed to obtain answer via API.'
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--data', type=str, nargs='+', required=True)
parser.add_argument('--model', type=str, nargs='+', required=True)
parser.add_argument('--nproc', type=int, default=4, required=True)
parser.add_argument('--verbose', action='store_true')
args = parser.parse_args()
return args
# Only API models are accepted
def infer_data_api(work_dir, model_name, dataset, nframe=8, pack=False, samples_dict={}, api_nproc=4):
rank, world_size = get_rank_and_world_size()
assert rank == 0 and world_size == 1
dataset_name = dataset.dataset_name
model = supported_VLM[model_name]() if isinstance(model_name, str) else model_name
assert getattr(model, 'is_api', False)
indices = list(samples_dict.keys())
structs = [dataset.build_prompt(samples_dict[idx], num_frames=nframe,
video_llm=getattr(model, 'VIDEO_LLM', False)) for idx in indices]
packstr = 'pack' if pack else 'nopack'
out_file = f'{work_dir}/{model_name}_{dataset_name}_{nframe}frame_{packstr}_supp.pkl'
res = load(out_file) if osp.exists(out_file) else {}
structs = [s for i, s in zip(indices, structs) if i not in res or res[i] == FAIL_MSG]
indices = [i for i in indices if i not in res or res[i] == FAIL_MSG]
gen_func = model.generate
structs = [dict(message=struct, dataset=dataset_name) for struct in structs]
if len(structs):
track_progress_rich(gen_func, structs, nproc=api_nproc, chunksize=api_nproc, save=out_file, keys=indices)
res = load(out_file)
return res
def infer_data(model_name, work_dir, dataset, out_file, nframe=8, pack=False, verbose=False, api_nproc=4):
res = load(out_file) if osp.exists(out_file) else {}
rank, world_size = get_rank_and_world_size()
dataset_name = dataset.dataset_name
sample_indices = list(dataset.videos) if pack else list(dataset.data['index'])
samples = list(dataset.videos) if pack else list(range(len(dataset.data)))
sample_map = {i: s for i, s in zip(sample_indices, samples)}
sample_indices_sub = sample_indices[rank::world_size]
if np.all([idx in res for idx in sample_indices_sub]):
return model_name
sample_indices_subrem = [x for x in sample_indices_sub if x not in res]
model = supported_VLM[model_name]() if isinstance(model_name, str) else model_name
is_api = getattr(model, 'is_api', False)
if is_api:
assert world_size == 1
supp = infer_data_api(
work_dir=work_dir,
model_name=model_name,
dataset=dataset,
nframe=nframe,
pack=pack,
samples_dict={k: sample_map[k] for k in sample_indices_subrem},
api_nproc=api_nproc)
for k in sample_indices_subrem:
assert k in supp
res.update(supp)
dump(res, out_file)
return model_name
for i, idx in tqdm(enumerate(sample_indices_subrem)):
if idx in res:
continue
# Use the model's own frame count if it specifies one; otherwise keep the requested nframe
nframe = getattr(model, 'nframe', 0) if getattr(model, 'nframe', 0) > 0 else nframe
# When using a video LLM, build_prompt returns video + question; otherwise, several frames + question
if hasattr(model, 'use_custom_prompt') and model.use_custom_prompt(dataset_name):
struct = model.build_prompt(
dataset.data.iloc[sample_map[idx]], dataset=dataset,
num_frames=nframe, video_llm=getattr(model, 'VIDEO_LLM', False)
)
else:
struct = dataset.build_prompt(
sample_map[idx], num_frames=nframe,
video_llm=getattr(model, 'VIDEO_LLM', False)
)
response = model.generate(message=struct, dataset=dataset_name)
torch.cuda.empty_cache()
if verbose:
print(response, flush=True)
res[idx] = response
if (i + 1) % 20 == 0:
dump(res, out_file)
res = {k: res[k] for k in sample_indices_sub}
dump(res, out_file)
return model
# A wrapper for infer_data that handles the pre- & post-processing
def infer_data_job_video(
model,
work_dir,
model_name,
dataset,
nframe=8,
pack=False,
verbose=False,
subtitle=False,
api_nproc=4):
dataset_name = dataset.dataset_name
packstr = 'pack' if pack else 'nopack'
rank, world_size = get_rank_and_world_size()
result_file = osp.join(work_dir, f'{model_name}_{dataset_name}_{nframe}frame_{packstr}.xlsx')
if dataset_name == 'Video-MME':
subtitle_str = 'subs' if subtitle else 'nosubs'
result_file = result_file.replace('.xlsx', f'_{subtitle_str}.xlsx')
# If the result file already exists, skip inference and return directly
if osp.exists(result_file):
return model_name
tmpl = osp.join(work_dir, '{}' + f'{world_size}_{dataset_name}_{nframe}frame_{packstr}.pkl')
if dataset_name == 'Video-MME':
subtitle_str = 'subs' if subtitle else 'nosubs'
tmpl = tmpl.replace('.pkl', f'_{subtitle_str}.pkl')
out_file = tmpl.format(rank)
model = infer_data(
model,
work_dir=work_dir,
dataset=dataset,
nframe=nframe,
pack=pack,
out_file=out_file,
verbose=verbose,
api_nproc=api_nproc)
if world_size > 1:
dist.barrier()
if rank == 0:
data_all = {}
for i in range(world_size):
data_all.update(load(tmpl.format(i)))
meta = dataset.data
if dataset_name == 'MMBench-Video' and pack:
meta, vstats = dataset.load_pack_answers(data_all)
print(f'Statistics of Pack Video Inference: {vstats}')
else:
for x in meta['index']:
assert x in data_all
meta['prediction'] = [str(data_all[x]) for x in meta['index']]
if 'image' in meta:
meta.pop('image')
dump(meta, result_file)
for i in range(world_size):
os.remove(tmpl.format(i))
return model
from .file import *
from .vlm import *
from .misc import *
from .log import *
import json
import pickle
import pandas as pd
import os
import csv
import hashlib
import os.path as osp
import time
import numpy as np
import validators
import mimetypes
import multiprocessing as mp
from .misc import toliststr
from .vlm import decode_base64_to_image_file
def decode_img_omni(tup):
root, im, p = tup
images = toliststr(im)
paths = toliststr(p)
if len(images) > 1 and len(paths) == 1:
paths = [osp.splitext(p)[0] + f'_{i}' + osp.splitext(p)[1] for i in range(len(images))]
assert len(images) == len(paths)
paths = [osp.join(root, p) for p in paths]
for p, im in zip(paths, images):
if osp.exists(p):
continue
if isinstance(im, str) and len(im) > 64:
decode_base64_to_image_file(im, p)
return paths
def localize_df(data, dname, nproc=32):
assert 'image' in data
indices = list(data['index'])
indices_str = [str(x) for x in indices]
images = list(data['image'])
image_map = {x: y for x, y in zip(indices_str, images)}
root = LMUDataRoot()
root = osp.join(root, 'images', dname)
os.makedirs(root, exist_ok=True)
if 'image_path' in data:
img_paths = list(data['image_path'])
else:
img_paths = []
for i in indices_str:
if len(image_map[i]) <= 64:
idx = image_map[i]
assert idx in image_map and len(image_map[idx]) > 64
img_paths.append(f'{idx}.jpg')
else:
img_paths.append(f'{i}.jpg')
tups = [(root, im, p) for p, im in zip(img_paths, images)]
pool = mp.Pool(nproc)
ret = pool.map(decode_img_omni, tups)
pool.close()
data.pop('image')
if 'image_path' not in data:
data['image_path'] = [x[0] if len(x) == 1 else x for x in ret]
return data
def LMUDataRoot():
if 'LMUData' in os.environ and osp.exists(os.environ['LMUData']):
return os.environ['LMUData']
home = osp.expanduser('~')
root = osp.join(home, 'LMUData')
os.makedirs(root, exist_ok=True)
return root
def MMBenchOfficialServer(dataset_name):
root = LMUDataRoot()
if dataset_name in ['MMBench', 'MMBench_V11', 'MMBench_CN', 'MMBench_CN_V11']:
ans_file = f'{root}/{dataset_name}.tsv'
if osp.exists(ans_file):
data = load(ans_file)
if 'answer' in data and sum([pd.isna(x) for x in data['answer']]) == 0:
return True
if dataset_name in ['MMBench_TEST_EN', 'MMBench_TEST_CN', 'MMBench_TEST_EN_V11', 'MMBench_TEST_CN_V11']:
ans_file1 = f'{root}/{dataset_name}.tsv'
mapp = {
'MMBench_TEST_EN': 'MMBench', 'MMBench_TEST_CN': 'MMBench_CN',
'MMBench_TEST_EN_V11': 'MMBench_V11', 'MMBench_TEST_CN_V11': 'MMBench_CN_V11',
}
ans_file2 = f'{root}/{mapp[dataset_name]}.tsv'
for f in [ans_file1, ans_file2]:
if osp.exists(f):
data = load(f)
if 'answer' in data and sum([pd.isna(x) for x in data['answer']]) == 0:
return True
return False
class NumpyEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, (np.int_, np.intc, np.intp, np.int8,
np.int16, np.int32, np.int64, np.uint8,
np.uint16, np.uint32, np.uint64)):
return int(obj)
elif isinstance(obj, (np.float_, np.float16, np.float32, np.float64)):
return float(obj)
elif isinstance(obj, (np.complex_, np.complex64, np.complex128)):
return {'real': obj.real, 'imag': obj.imag}
elif isinstance(obj, (np.ndarray,)):
return obj.tolist()
elif isinstance(obj, (np.bool_)):
return bool(obj)
elif isinstance(obj, (np.void)):
return None
return json.JSONEncoder.default(self, obj)
# LOAD & DUMP
def dump(data, f, **kwargs):
def dump_pkl(data, pth, **kwargs):
pickle.dump(data, open(pth, 'wb'))
def dump_json(data, pth, **kwargs):
json.dump(data, open(pth, 'w'), indent=4, ensure_ascii=False, cls=NumpyEncoder)
def dump_jsonl(data, f, **kwargs):
lines = [json.dumps(x, ensure_ascii=False, cls=NumpyEncoder) for x in data]
with open(f, 'w', encoding='utf8') as fout:
fout.write('\n'.join(lines))
def dump_xlsx(data, f, **kwargs):
data.to_excel(f, index=False, engine='xlsxwriter')
def dump_csv(data, f, quoting=csv.QUOTE_ALL):
data.to_csv(f, index=False, encoding='utf-8', quoting=quoting)
def dump_tsv(data, f, quoting=csv.QUOTE_ALL):
data.to_csv(f, sep='\t', index=False, encoding='utf-8', quoting=quoting)
handlers = dict(pkl=dump_pkl, json=dump_json, jsonl=dump_jsonl, xlsx=dump_xlsx, csv=dump_csv, tsv=dump_tsv)
suffix = f.split('.')[-1]
return handlers[suffix](data, f, **kwargs)
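# Usage sketch (hypothetical paths): the handler is picked from the file suffix.
# >>> dump({'acc': 0.5}, '/tmp/score.json')              # json.dump with NumpyEncoder
# >>> dump(pd.DataFrame({'a': [1]}), '/tmp/pred.xlsx')   # DataFrame.to_excel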
def load(f, fmt=None):
def load_pkl(pth):
return pickle.load(open(pth, 'rb'))
def load_json(pth):
return json.load(open(pth, 'r', encoding='utf-8'))
def load_jsonl(f):
lines = open(f, encoding='utf-8').readlines()
lines = [x.strip() for x in lines]
if lines[-1] == '':
lines = lines[:-1]
data = [json.loads(x) for x in lines]
return data
def load_xlsx(f):
return pd.read_excel(f)
def load_csv(f):
return pd.read_csv(f)
def load_tsv(f):
return pd.read_csv(f, sep='\t')
handlers = dict(pkl=load_pkl, json=load_json, jsonl=load_jsonl, xlsx=load_xlsx, csv=load_csv, tsv=load_tsv)
if fmt is not None:
return handlers[fmt](f)
suffix = f.split('.')[-1]
return handlers[suffix](f)
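# Usage sketch (hypothetical paths): `fmt` overrides the suffix-based dispatch,
# e.g. for files without a standard extension.
# >>> df = load('/tmp/pred.xlsx')              # -> pd.DataFrame
# >>> res = load('/tmp/cache.bin', fmt='pkl')  # force the pickle loader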
def download_file(url, filename=None):
import urllib.request
from tqdm import tqdm
class DownloadProgressBar(tqdm):
def update_to(self, b=1, bsize=1, tsize=None):
if tsize is not None:
self.total = tsize
self.update(b * bsize - self.n)
if filename is None:
filename = url.split('/')[-1]
try:
with DownloadProgressBar(unit='B', unit_scale=True, miniters=1, desc=url.split('/')[-1]) as t:
urllib.request.urlretrieve(url, filename=filename, reporthook=t.update_to)
except Exception:
# Handle Failed Downloads from huggingface.co
if 'huggingface.co' in url:
url_new = url.replace('huggingface.co', 'hf-mirror.com')
try:
download_file(url_new, filename)
return filename
except Exception:
raise Exception(f'Failed to download {url}')
else:
raise Exception(f'Failed to download {url}')
return filename
def ls(dirname='.', match=[], mode='all', level=1):
if isinstance(level, str):
assert '+' in level
level = int(level[:-1])
res = []
for i in range(1, level + 1):
res.extend(ls(dirname, match=match, mode='file', level=i))
return res
if dirname == '.':
ans = os.listdir(dirname)
else:
ans = [osp.join(dirname, x) for x in os.listdir(dirname)]
assert mode in ['all', 'dir', 'file']
assert level >= 1 and isinstance(level, int)
if level == 1:
if isinstance(match, str):
match = [match]
for m in match:
if len(m) == 0:
continue
if m[0] != '!':
ans = [x for x in ans if m in x]
else:
ans = [x for x in ans if m[1:] not in x]
if mode == 'dir':
ans = [x for x in ans if osp.isdir(x)]
elif mode == 'file':
ans = [x for x in ans if not osp.isdir(x)]
return ans
else:
dirs = [x for x in ans if osp.isdir(x)]
res = []
for d in dirs:
res.extend(ls(d, match=match, mode=mode, level=level - 1))
return res
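# Usage sketch (hypothetical layout):
# >>> ls('outputs', match='.xlsx', mode='file')           # xlsx files directly under outputs/
# >>> ls('outputs', match=['score', '!tmp'], level='2+')  # files up to 2 levels deep whose
# ...                                                     # paths contain 'score' but not 'tmp'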
def mrlines(fname, sp='\n'):
f = open(fname).read().split(sp)
while f != [] and f[-1] == '':
f = f[:-1]
return f
def mwlines(lines, fname):
with open(fname, 'w') as fout:
fout.write('\n'.join(lines))
def md5(s):
hash = hashlib.new('md5')
if osp.exists(s):
with open(s, 'rb') as f:
for chunk in iter(lambda: f.read(2**20), b''):
hash.update(chunk)
else:
hash.update(s.encode('utf-8'))
return str(hash.hexdigest())
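# md5 hashes the file content when `s` is an existing path, otherwise the string itself.
# >>> md5('hello')
# '5d41402abc4b2a76b9719d911017c592'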
def last_modified(pth):
stamp = osp.getmtime(pth)
m_ti = time.ctime(stamp)
t_obj = time.strptime(m_ti)
t = time.strftime('%Y%m%d%H%M%S', t_obj)[2:]
return t
def parse_file(s):
if osp.exists(s) and s != '.':
assert osp.isfile(s)
suffix = osp.splitext(s)[1].lower()
mime = mimetypes.types_map.get(suffix, 'unknown')
return (mime, s)
elif validators.url(s):
suffix = osp.splitext(s)[1].lower()
if suffix in mimetypes.types_map:
mime = mimetypes.types_map[suffix]
dname = osp.join(LMUDataRoot(), 'files')
os.makedirs(dname, exist_ok=True)
tgt = osp.join(dname, md5(s) + suffix)
download_file(s, tgt)
return (mime, tgt)
else:
return ('url', s)
else:
return (None, s)
def file_size(f, unit='GB'):
stats = os.stat(f)
div_map = {
'GB': 2 ** 30,
'MB': 2 ** 20,
'KB': 2 ** 10,
}
return stats.st_size / div_map[unit]
def parquet_to_tsv(file_path):
data = pd.read_parquet(file_path)
pth = '/'.join(file_path.split('/')[:-1])
data_name = file_path.split('/')[-1].split('.')[0]
data.to_csv(osp.join(pth, f'{data_name}.tsv'), sep='\t', index=False)
import logging
logger_initialized = {}
def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'):
logger = logging.getLogger(name)
if name in logger_initialized:
return logger
for logger_name in logger_initialized:
if name.startswith(logger_name):
return logger
stream_handler = logging.StreamHandler()
handlers = [stream_handler]
try:
import torch.distributed as dist
if dist.is_available() and dist.is_initialized():
rank = dist.get_rank()
else:
rank = 0
except ImportError:
rank = 0
if rank == 0 and log_file is not None:
file_handler = logging.FileHandler(log_file, file_mode)
handlers.append(file_handler)
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s')
for handler in handlers:
handler.setFormatter(formatter)
handler.setLevel(log_level)
logger.addHandler(handler)
if rank == 0:
logger.setLevel(log_level)
else:
logger.setLevel(logging.ERROR)
logger_initialized[name] = True
return logger
# flake8: noqa: F401, F403
import abc
import argparse
import csv
import multiprocessing as mp
import os
import os.path as osp
import copy as cp
import random as rd
import requests
import shutil
import subprocess
import warnings
import logging
import pandas as pd
from collections import OrderedDict, defaultdict
from multiprocessing import Pool, current_process
from tqdm import tqdm
import datetime
import matplotlib.pyplot as plt
from tabulate import tabulate
from json import JSONDecoder
from huggingface_hub import scan_cache_dir
from sty import fg, bg, ef, rs
def process_punctuation(inText):
import re
outText = inText
punct = [
';', r'/', '[', ']', '"', '{', '}', '(', ')', '=', '+', '\\', '_', '-',
'>', '<', '@', '`', ',', '?', '!'
]
commaStrip = re.compile(r'(\d)(,)(\d)')
periodStrip = re.compile(r'(?!<=\d)(\.)(?!\d)')
for p in punct:
if (p + ' ' in inText or ' ' + p in inText) or (re.search(
commaStrip, inText) is not None):
outText = outText.replace(p, '')
else:
outText = outText.replace(p, ' ')
outText = periodStrip.sub('', outText)  # Pattern.sub's third positional arg is `count`, not flags
return outText
def h2r(value):
if value[0] == '#':
value = value[1:]
assert len(value) == 6
return tuple(int(value[i:i + 2], 16) for i in range(0, 6, 2))
def r2h(rgb):
return '#%02x%02x%02x' % rgb
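# Round-trip sketch between hex color strings and RGB tuples:
# >>> h2r('#ff0080')
# (255, 0, 128)
# >>> r2h((255, 0, 128))
# '#ff0080'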
def colored(s, color):
if isinstance(color, str):
if hasattr(fg, color):
return getattr(fg, color) + s + fg.rs
color = h2r(color)
return fg(*color) + s + fg.rs
def istype(s, type):
if isinstance(s, type):
return True
try:
return isinstance(eval(s), type)
except Exception as _:
return False
def bincount(lst):
bins = defaultdict(lambda: 0)
for item in lst:
bins[item] += 1
return bins
def get_cache_path(repo_id, branch=None):
hf_cache_info = scan_cache_dir()
repos = list(hf_cache_info.repos)
repo = None
for r in repos:
if r.repo_id == repo_id:
repo = r
break
if repo is None:
return None
revs = list(repo.revisions)
if branch is not None:
revs = [r for r in revs if r.refs == frozenset({branch})]
rev2keep, last_modified = None, 0
for rev in revs:
if rev.last_modified > last_modified:
rev2keep, last_modified = rev, rev.last_modified
if rev2keep is None:
return None
return str(rev2keep.snapshot_path)
def proxy_set(s):
import os
for key in ['http_proxy', 'HTTP_PROXY', 'https_proxy', 'HTTPS_PROXY']:
os.environ[key] = s
def get_rank_and_world_size():
rank = int(os.environ.get('RANK', 0))
world_size = int(os.environ.get('WORLD_SIZE', 1))
return rank, world_size
def splitlen(s, sym='/'):
return len(s.split(sym))
def listinstr(lst, s):
assert isinstance(lst, list)
for item in lst:
if item in s:
return True
return False
def d2df(D):
return pd.DataFrame({x: [D[x]] for x in D})
def cn_string(s):
import re
if re.search(u'[\u4e00-\u9fff]', s):
return True
return False
try:
import decord
except ImportError:
pass
def timestr(second=True, minute=False):
s = datetime.datetime.now().strftime('%Y%m%d%H%M%S')[2:]
if second:
return s
elif minute:
return s[:-2]
else:
return s[:-4]
def dict_merge(dct, merge_dct):
for k, _ in merge_dct.items():
if (k in dct and isinstance(dct[k], dict) and isinstance(merge_dct[k], dict)): #noqa
dict_merge(dct[k], merge_dct[k])
else:
dct[k] = merge_dct[k]
def youtube_dl(idx):
cmd = f'youtube-dl -f best -f mp4 "{idx}" -o {idx}.mp4'
os.system(cmd)
def run_command(cmd):
if isinstance(cmd, str):
cmd = cmd.split()
return subprocess.check_output(cmd).decode()
def load_env():
logger = logging.getLogger('LOAD_ENV')
try:
import vlmeval
except ImportError:
logger.error('VLMEval is not installed. Failed to import environment variables from .env file. ')
return
pth = osp.realpath(vlmeval.__path__[0])
pth = osp.join(pth, '../.env')
pth = osp.realpath(pth)
if not osp.exists(pth):
logger.error(f'Did not detect the .env file at {pth}, failed to load. ')
return
from dotenv import dotenv_values
values = dotenv_values(pth)
for k, v in values.items():
if v is not None and len(v):
os.environ[k] = v
logger.info(f'API Keys successfully loaded from {pth}')
def pip_install_robust(package):
import sys
retry = 3
while retry > 0:
try:
package_base = package.split('=')[0]
module = __import__(package)
return True
except ImportError:
subprocess.check_call([sys.executable, '-m', 'pip', 'install', package])
retry -= 1
return False
def version_cmp(v1, v2, op='eq'):
from packaging import version
import operator
op_func = getattr(operator, op)
return op_func(version.parse(v1), version.parse(v2))
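# Usage sketch: compare versions with any `operator` function name ('eq', 'ge', 'lt', ...).
# >>> version_cmp('4.37.0', '4.33.0', op='ge')
# True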
def toliststr(s):
if isinstance(s, str) and (s[0] == '[') and (s[-1] == ']'):
return [str(x) for x in eval(s)]
elif isinstance(s, str):
return [s]
elif isinstance(s, list):
return [str(x) for x in s]
raise NotImplementedError
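# Usage sketch: normalizes a str / stringified list / list into a list of str.
# >>> toliststr("['a.jpg', 'b.jpg']")
# ['a.jpg', 'b.jpg']
# >>> toliststr('a.jpg')
# ['a.jpg']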
def extract_json_objects(text, decoder=JSONDecoder()):
pos = 0
while True:
match = text.find('{', pos)
if match == -1: break
try:
result, index = decoder.raw_decode(text[match:])
yield result
pos = match + index
except ValueError:
pos = match + 1
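# Usage sketch: pull JSON objects out of free-form model output.
# >>> list(extract_json_objects('score is {"acc": 0.5}, done'))
# [{'acc': 0.5}]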
import os
import io
import pandas as pd
import numpy as np
import string
from uuid import uuid4
import os.path as osp
import base64
from PIL import Image
import sys
Image.MAX_IMAGE_PIXELS = 1e9
def rescale_img(img, tgt=None):
assert isinstance(tgt, tuple) and -1 in tgt
w, h = img.size
if tgt[0] != -1:
new_w, new_h = tgt[0], int(tgt[0] / w * h)
elif tgt[1] != -1:
new_w, new_h = int(tgt[1] / h * w), tgt[1]
img = img.resize((new_w, new_h))
return img
def concat_images_vlmeval(images, target_size=-1, mode='h', return_image=False):
from .file import md5
ims = [Image.open(im) for im in images]
if target_size != -1:
ims = [
rescale_img(im, (-1, target_size) if mode == 'h' else (target_size, -1))
for im in ims
]
ws, hs = [x.width for x in ims], [x.height for x in ims]
if mode == 'h':
new_w, new_h = sum(ws), max(hs)
dst = Image.new('RGB', (new_w, new_h))
for i, im in enumerate(ims):
dst.paste(im, (sum(ws[:i]), 0))
elif mode == 'v':
new_w, new_h = max(ws), sum(hs)
dst = Image.new('RGB', (new_w, new_h))
for i, im in enumerate(ims):
dst.paste(im, (0, sum(hs[:i])))
if return_image:
return dst
else:
_str = '\n'.join(images)
str_md5 = md5(_str)
tgt = osp.join('/tmp', str_md5 + '.jpg')
dst.save(tgt)
return tgt
def mmqa_display(question, target_size=512):
question = {k.lower(): v for k, v in question.items()}
keys = list(question.keys())
keys = [k for k in keys if k not in ['index', 'image']]
images = question['image']
if isinstance(images, str):
images = [images]
idx = question.pop('index', 'XXX')
print(f'INDEX: {idx}')
for im in images:
image = decode_base64_to_image(im, target_size=target_size)
display(image) # noqa: F821
for k in keys:
try:
if not pd.isna(question[k]):
print(f'{k.upper()}. {question[k]}')
except ValueError:
if False in pd.isna(question[k]):
print(f'{k.upper()}. {question[k]}')
def encode_image_to_base64(img, target_size=-1):
# if target_size == -1, will not do resizing
# else, will set the max size to (target_size, target_size)
if img.mode in ('RGBA', 'P'):
img = img.convert('RGB')
if target_size > 0:
img.thumbnail((target_size, target_size))
img_buffer = io.BytesIO()
img.save(img_buffer, format='JPEG')
image_data = img_buffer.getvalue()
ret = base64.b64encode(image_data).decode('utf-8')
return ret
def encode_image_file_to_base64(image_path, target_size=-1):
image = Image.open(image_path)
return encode_image_to_base64(image, target_size=target_size)
def decode_base64_to_image(base64_string, target_size=-1):
image_data = base64.b64decode(base64_string)
image = Image.open(io.BytesIO(image_data))
if image.mode in ('RGBA', 'P'):
image = image.convert('RGB')
if target_size > 0:
image.thumbnail((target_size, target_size))
return image
def decode_base64_to_image_file(base64_string, image_path, target_size=-1):
image = decode_base64_to_image(base64_string, target_size=target_size)
image.save(image_path)
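# Round-trip sketch (hypothetical paths):
# >>> b64 = encode_image_file_to_base64('cat.jpg', target_size=512)
# >>> decode_base64_to_image_file(b64, 'cat_restored.jpg')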
def build_option_str(option_dict):
s = 'There are several options: \n'
for c, content in option_dict.items():
if not pd.isna(content):
s += f'{c}. {content}\n'
return s
def isimg(s):
return osp.exists(s) or s.startswith('http')
def read_ok(img_path):
if not osp.exists(img_path):
return False
try:
im = Image.open(img_path)
assert im.size[0] > 0 and im.size[1] > 0
return True
except Exception:
return False
def gpt_key_set():
openai_key = os.environ.get('OPENAI_API_KEY', None)
return isinstance(openai_key, str) and openai_key.startswith('sk-')
def apiok(wrapper):
s = wrapper.generate('Hello!')
return wrapper.fail_msg not in s
def circular_pred(df, extract_func=None):
if extract_func is None:
extract_func = lambda x: x # noqa: E731
df = df.sort_values('index')
from vlmeval.utils import can_infer_option
shift = int(1e6)
choices = [extract_func(x) for x in df['prediction']]
pred_map = {i: c for i, c in zip(df['index'], choices)}
flag_map = {i: True for i in pred_map if i < 1e6}
valid_map = {i: True for i in pred_map if i < 1e6}
for i in df['index']:
if i >= shift and pred_map[i] and pred_map[i - shift]:
if pred_map[i] not in list(
string.ascii_uppercase
) or pred_map[ # noqa: W504
i - shift
] not in list(
string.ascii_uppercase
):
valid_map[i % shift] = False
continue
if (ord(pred_map[i]) - ord(pred_map[i - shift])) % 4 == 1:
continue
else:
flag_map[i % shift] = False
flag_map = {k: v for k, v in flag_map.items() if valid_map[k]}
flags = list(flag_map.values())
return np.mean(flags)
import sys
from vlmeval.config import *
from vlmeval.smp import *
# Define valid modes
MODES = ('dlist', 'mlist', 'missing', 'circular', 'localize', 'check', 'run', 'eval')
CLI_HELP_MSG = \
f"""
Arguments received: {str(['vlmutil'] + sys.argv[1:])}. vlmutil commands use the following syntax:
vlmutil MODE MODE_ARGS
Where MODE (required) is one of {MODES}
MODE_ARG (optional) is the argument for specific mode
Some usages for vlmutil commands: (See more by using -h for specific command!)
1. List all the datasets by level: l1, l2, l3, etc.:
vlmutil dlist [l1/l2/l3/...]
2. List all the models by category: 4.33.0, 4.37.0, api, etc.:
vlmutil mlist 4.33.0 [all/small/large]
3. Report missing results:
vlmutil missing [l1/l2/l3/...]
4. Create circular questions (only for multiple-choice questions with no more than 4 choices):
vlmutil circular input.tsv
5. Create a localized version of the dataset (for very large tsv files):
vlmutil localize input.tsv
6. Check the validity of a model:
vlmutil check [model_name/model_series]
7. Run evaluation for missing results:
vlmutil run l2 hf
8. Evaluate data file:
vlmutil eval [dataset_name] [prediction_file]
GitHub: https://github.com/open-compass/VLMEvalKit
""" # noqa: E501
dataset_levels = {
'l1': [
('MMVet', 'gpt-4-turbo_score.csv'), ('MMMU_DEV_VAL', 'acc.csv'),
('MathVista_MINI', 'gpt-4-turbo_score.csv'), ('HallusionBench', 'score.csv'),
('OCRBench', 'score.json'), ('AI2D_TEST', 'acc.csv'), ('MMStar', 'acc.csv'),
('MMBench_V11', 'acc.csv'), ('MMBench_CN_V11', 'acc.csv')
],
'l2': [
('MME', 'score.csv'), ('LLaVABench', 'score.csv'), ('RealWorldQA', 'acc.csv'),
('MMBench', 'acc.csv'), ('MMBench_CN', 'acc.csv'), ('CCBench', 'acc.csv'),
('SEEDBench_IMG', 'acc.csv'), ('COCO_VAL', 'score.json'), ('POPE', 'score.csv'),
('ScienceQA_VAL', 'acc.csv'), ('ScienceQA_TEST', 'acc.csv'), ('MMT-Bench_VAL', 'acc.csv'),
('SEEDBench2_Plus', 'acc.csv'), ('BLINK', 'acc.csv'), ('MTVQA_TEST', 'acc.json'),
('Q-Bench1_VAL', 'acc.csv'), ('A-Bench_VAL', 'acc.csv')
],
'l3': [
('OCRVQA_TESTCORE', 'acc.csv'), ('TextVQA_VAL', 'acc.csv'),
('ChartQA_TEST', 'acc.csv'), ('DocVQA_VAL', 'acc.csv'), ('InfoVQA_VAL', 'acc.csv'),
('SEEDBench2', 'acc.csv')
]
}
dataset_levels['l12'] = dataset_levels['l1'] + dataset_levels['l2']
dataset_levels['l23'] = dataset_levels['l2'] + dataset_levels['l3']
dataset_levels['l123'] = dataset_levels['l12'] + dataset_levels['l3']
models = {
'4.33.0': list(qwen_series) + list(xcomposer_series) + [
'mPLUG-Owl2', 'flamingov2', 'VisualGLM_6b', 'MMAlaya', 'PandaGPT_13B', 'VXVERSE'
] + list(idefics_series) + list(minigpt4_series) + list(instructblip_series),
'4.37.0': [x for x in llava_series if 'next' not in x] + list(internvl_series) + [
'TransCore_M', 'emu2_chat', 'MiniCPM-V', 'MiniCPM-V-2', 'OmniLMM_12B',
'cogvlm-grounding-generalist', 'cogvlm-chat', 'cogvlm2-llama3-chat-19B',
'mPLUG-Owl3'
] + list(xtuner_series) + list(yivl_series) + list(deepseekvl_series) + list(cambrian_series),
'4.36.2': ['Moondream1'],
'4.40.0': [
'idefics2_8b', 'Bunny-llama3-8B', 'MiniCPM-Llama3-V-2_5', '360VL-70B', 'Phi-3-Vision',
] + list(wemm_series),
'4.44.0': ['Moondream2'],
'latest': ['paligemma-3b-mix-448', 'MiniCPM-V-2_6', 'glm-4v-9b'] + [x for x in llava_series if 'next' in x]
+ list(chameleon_series) + list(ovis_series) + list(mantis_series),
'api': list(api_models)
}
# SKIP_MODELS will be skipped in MISSING and RUN
SKIP_MODELS = [
'MGM_7B', 'GPT4V_HIGH', 'GPT4V', 'flamingov2', 'PandaGPT_13B',
'GeminiProVision', 'Step1V-0701', 'SenseChat-5-Vision',
'llava_v1_7b', 'sharegpt4v_7b', 'sharegpt4v_13b',
'llava-v1.5-7b-xtuner', 'llava-v1.5-13b-xtuner',
'cogvlm-grounding-generalist', 'InternVL-Chat-V1-1',
'InternVL-Chat-V1-2', 'InternVL-Chat-V1-2-Plus', 'RekaCore',
'llava_next_72b', 'llava_next_110b', 'MiniCPM-V', 'sharecaptioner', 'XComposer',
'VisualGLM_6b', 'idefics_9b_instruct', 'idefics_80b_instruct',
'mPLUG-Owl2', 'MMAlaya', 'OmniLMM_12B', 'emu2_chat', 'VXVERSE'
] + list(minigpt4_series) + list(instructblip_series) + list(xtuner_series) + list(chameleon_series) + list(vila_series)
LARGE_MODELS = [
'idefics_80b_instruct', '360VL-70B', 'emu2_chat', 'InternVL2-76B',
]
def completed(m, d, suf):
score_file = f'outputs/{m}/{m}_{d}_{suf}'
if osp.exists(score_file):
return True
if d == 'MMBench':
s1, s2 = f'outputs/{m}/{m}_MMBench_DEV_EN_{suf}', f'outputs/{m}/{m}_MMBench_TEST_EN_{suf}'
return osp.exists(s1) and osp.exists(s2)
elif d == 'MMBench_CN':
s1, s2 = f'outputs/{m}/{m}_MMBench_DEV_CN_{suf}', f'outputs/{m}/{m}_MMBench_TEST_CN_{suf}'
return osp.exists(s1) and osp.exists(s2)
return False
def DLIST(lvl):
lst = [x[0] for x in dataset_levels[lvl]]
return lst
def MLIST(lvl, size='all'):
if lvl == 'all':
from vlmeval.config import supported_VLM
return [x for x in supported_VLM]
model_list = models[lvl]
if size == 'small':
model_list = [m for m in model_list if m not in LARGE_MODELS]
elif size == 'large':
model_list = [m for m in model_list if m in LARGE_MODELS]
return model_list
def MISSING(lvl):
from vlmeval.config import supported_VLM
models = list(supported_VLM)
models = [m for m in models if m not in SKIP_MODELS and osp.exists(osp.join('outputs', m))]
if lvl in dataset_levels.keys():
data_list = dataset_levels[lvl]
else:
data_list = [(D, suff) for (D, suff) in dataset_levels['l123'] if D == lvl]
missing_list = []
for f in models:
for D, suff in data_list:
if not completed(f, D, suff):
missing_list.append((f, D))
return missing_list
def CIRCULAR(inp):
assert inp.endswith('.tsv')
data = load(inp)
OFFSET = 1e6
while max(data['index']) >= OFFSET:
OFFSET *= 10
assert 'E' not in data, 'Currently build_circular only works for up to 4-choice questions'
data_2c = data[pd.isna(data['C'])]
data_3c = data[~pd.isna(data['C']) & pd.isna(data['D'])]
data_4c = data[~pd.isna(data['D'])]
map_2c = [('AB', 'BA')]
map_3c = [('ABC', 'BCA'), ('ABC', 'CAB')]
map_4c = [('ABCD', 'BCDA'), ('ABCD', 'CDAB'), ('ABCD', 'DABC')]
def okn(o, n=4):
ostr = o.replace(',', ' ')
osplits = ostr.split()
if sum([c in osplits for c in string.ascii_uppercase[:n - 1]]) == n - 1:
return False
olower = o.lower()
olower = olower.replace(',', ' ')
olower_splits = olower.split()
if 'all' in olower_splits or 'none' in olower_splits:
return False
return True
yay4, nay4 = [], []
lt4 = len(data_4c)
for i in range(lt4):
if okn(data_4c.iloc[i]['D'], 4):
yay4.append(i)
else:
nay4.append(i)
data_4c_y = data_4c.iloc[yay4]
data_4c_n = data_4c.iloc[nay4]
data_3c = pd.concat([data_4c_n, data_3c])
yay3, nay3 = [], []
lt3 = len(data_3c)
for i in range(lt3):
if okn(data_3c.iloc[i]['C'], 3):
yay3.append(i)
else:
nay3.append(i)
data_3c_y = data_3c.iloc[yay3]
data_3c_n = data_3c.iloc[nay3]
data_2c = pd.concat([data_3c_n, data_2c])
def remap(data_in, tup, off):
off = int(off)
data = data_in.copy()
char_map = {k: v for k, v in zip(*tup)}
idx = data.pop('index')
answer = data.pop('answer')
answer_new = [char_map[x] if x in char_map else x for x in answer]
data['answer'] = answer_new
options = {}
for c in char_map:
options[char_map[c]] = data.pop(c)
for c in options:
data[c] = options[c]
data.pop('image')
data['image'] = idx
idx = [x + off for x in idx]
data['index'] = idx
return data
data_all = pd.concat([
data_2c,
data_3c_y,
data_4c_y,
remap(data_2c, map_2c[0], OFFSET),
remap(data_3c_y, map_3c[0], OFFSET),
remap(data_4c_y, map_4c[0], OFFSET),
remap(data_3c_y, map_3c[1], OFFSET * 2),
remap(data_4c_y, map_4c[1], OFFSET * 2),
remap(data_4c_y, map_4c[2], OFFSET * 3),
])
tgt_file = inp.replace('.tsv', '_CIRC.tsv')
dump(data_all, tgt_file)
print(f'The circularized data is saved to {tgt_file}')
assert osp.exists(tgt_file)
print(f'The MD5 for the circularized data is {md5(tgt_file)}')
PTH = osp.realpath(__file__)
IMAGE_PTH = osp.join(osp.dirname(PTH), '../assets/apple.jpg')
msg1 = [
IMAGE_PTH,
'What is in this image?'
]
msg2 = [
dict(type='image', value=IMAGE_PTH),
dict(type='text', value='What is in this image?')
]
msg3 = [
IMAGE_PTH,
IMAGE_PTH,
'How many apples are there in these images?'
]
msg4 = [
dict(type='image', value=IMAGE_PTH),
dict(type='image', value=IMAGE_PTH),
dict(type='text', value='How many apples are there in these images?')
]
def CHECK(val):
if val in supported_VLM:
model = supported_VLM[val]()
print(f'Model: {val}')
for i, msg in enumerate([msg1, msg2, msg3, msg4]):
if i > 1 and not model.INTERLEAVE:
continue
res = model.generate(msg)
print(f'Test {i + 1}: {res}')
elif val in models:
model_list = models[val]
for m in model_list:
CHECK(m)
def LOCALIZE(fname, new_fname=None):
if new_fname is None:
new_fname = fname.replace('.tsv', '_local.tsv')
base_name = osp.basename(fname)
dname = osp.splitext(base_name)[0]
data = load(fname)
data_new = localize_df(data, dname)
dump(data_new, new_fname)
print(f'The localized version of data file is {new_fname}')
return new_fname
def RUN(lvl, model):
import torch
NGPU = torch.cuda.device_count()
SCRIPT = osp.join(osp.dirname(__file__), '../run.py')
logger = get_logger('Run Missing')
def get_env(name):
assert name in ['433', '437', '440', 'latest']
load_env()
env_key = f'ENV_{name}'
return os.environ.get(env_key, None)
missing = MISSING(lvl)
if model == 'all':
pass
elif model == 'api':
missing = [x for x in missing if x[0] in models['api']]
elif model == 'hf':
missing = [x for x in missing if x[0] not in models['api']]
elif model in models:
missing = [x for x in missing if x[0] in models[model]]
elif model in supported_VLM:
missing = [x for x in missing if x[0] == model]
else:
warnings.warn(f'Invalid model {model}.')
missing.sort(key=lambda x: x[0])
groups = defaultdict(list)
for m, D in missing:
groups[m].append(D)
for m in groups:
if m in SKIP_MODELS:
continue
for dataset in groups[m]:
logger.info(f'Running {m} on {dataset}')
exe = 'python' if m in LARGE_MODELS or m in models['api'] else 'torchrun'
if m not in models['api']:
env = None
env = 'latest' if m in models['latest'] else env
env = '433' if m in models['4.33.0'] else env
env = '437' if m in models['4.37.0'] else env
env = '440' if m in models['4.40.0'] else env
if env is None:
# Not found, default to latest
env = 'latest'
logger.warning(
f"Model {m} does not have a specific environment configuration. Defaulting to 'latest'.")
pth = get_env(env)
if pth is not None:
exe = osp.join(pth, 'bin', exe)
else:
logger.warning(f'Cannot find the env path {env} for model {m}')
if exe.endswith('torchrun'):
cmd = f'{exe} --nproc-per-node={NGPU} {SCRIPT} --model {m} --data {dataset}'
elif exe.endswith('python'):
cmd = f'{exe} {SCRIPT} --model {m} --data {dataset}'
os.system(cmd)
def EVAL(dataset_name, data_file):
from vlmeval.dataset import build_dataset
logger = get_logger('VLMEvalKit Tool-Eval')
dataset = build_dataset(dataset_name)
# Set the judge kwargs first before evaluation or dumping
judge_kwargs = {'nproc': 4, 'verbose': True}
if dataset.TYPE in ['MCQ', 'Y/N'] or listinstr(['MathVerse'], dataset_name):
judge_kwargs['model'] = 'chatgpt-0125'
elif listinstr(['MMVet', 'MathVista', 'LLaVABench', 'MMBench-Video', 'MathVision'], dataset_name):
judge_kwargs['model'] = 'gpt-4-turbo'
elif listinstr(['MMLongBench', 'MMDU'], dataset_name):
judge_kwargs['model'] = 'gpt-4o'
eval_results = dataset.evaluate(data_file, **judge_kwargs)
if eval_results is not None:
assert isinstance(eval_results, dict) or isinstance(eval_results, pd.DataFrame)
logger.info('Evaluation Results:')
if isinstance(eval_results, dict):
logger.info('\n' + json.dumps(eval_results, indent=4))
elif isinstance(eval_results, pd.DataFrame):
if len(eval_results) < len(eval_results.columns):
eval_results = eval_results.T
logger.info('\n' + tabulate(eval_results))
def cli():
logger = get_logger('VLMEvalKit Tools')
args = sys.argv[1:]
if not args: # no arguments passed
logger.info(CLI_HELP_MSG)
return
if args[0].lower() in MODES:
if args[0].lower() == 'dlist':
assert len(args) >= 2
lst = DLIST(args[1])
print(' '.join(lst))
elif args[0].lower() == 'mlist':
assert len(args) >= 2
size = 'all'
if len(args) > 2:
size = args[2].lower()
lst = MLIST(args[1], size)
print('\n'.join(lst))
elif args[0].lower() == 'missing':
assert len(args) >= 2
missing_list = MISSING(args[1])
logger = get_logger('Find Missing')
logger.info(colored(f'Level {args[1]} Missing Results: ', 'red'))
lines = []
for m, D in missing_list:
line = f'Model {m}, Dataset {D}'
logger.info(colored(line, 'red'))
lines.append(line)
mwlines(lines, f'{args[1]}_missing.txt')
elif args[0].lower() == 'circular':
assert len(args) >= 2
CIRCULAR(args[1])
elif args[0].lower() == 'localize':
assert len(args) >= 2
LOCALIZE(args[1])
elif args[0].lower() == 'check':
assert len(args) >= 2
model_list = args[1:]
for m in model_list:
CHECK(m)
elif args[0].lower() == 'run':
assert len(args) >= 2
lvl = args[1]
if len(args) == 2:
model = 'all'
RUN(lvl, model)
else:
for model in args[2:]:
RUN(lvl, model)
elif args[0].lower() == 'eval':
assert len(args) == 3
dataset, data_file = args[1], args[2]
EVAL(dataset, data_file)
else:
logger.error(f'Invalid mode: {args[0]}!')
logger.info(CLI_HELP_MSG)
return
from .matching_util import can_infer, can_infer_option, can_infer_text
from .mp_util import track_progress_rich
__all__ = [
'can_infer', 'can_infer_option', 'can_infer_text', 'track_progress_rich',
]
import string
import copy as cp
import os
from ..smp import *
def can_infer_option(answer, choices):
verbose = os.environ.get('VERBOSE', 0)
# Choices is a dictionary
if 'Failed to obtain answer via API' in answer:
return False
reject_to_answer = [
"Sorry, I can't help with images of people yet.",
"I can't process this file.",
"I'm sorry, but without the image provided",
'Cannot determine the answer'
]
for err in reject_to_answer:
if err in answer:
return 'Z'
def count_choice(splits, choices, prefix='', suffix=''):
cnt = 0
for c in choices:
if prefix + c + suffix in splits:
cnt += 1
return cnt
answer_mod = cp.copy(answer)
chars = '.()[],:;!*#{}'
for c in chars:
answer_mod = answer_mod.replace(c, ' ')
splits = [x.strip() for x in answer_mod.split()]
count = count_choice(splits, choices)
if count == 1:
for ch in choices:
if 'A' in splits and len(splits) > 3 and verbose:
logger = get_logger('Evaluation')
logger.info(f'A might be a quantifier in the string: {answer}.')
return False
if ch in splits:
return ch
elif count == 0 and count_choice(splits, {'Z', ''}) == 1:
return 'Z'
return False
def can_infer_text(answer, choices):
answer = answer.lower()
assert isinstance(choices, dict)
for k in choices:
assert k in string.ascii_uppercase
choices[k] = str(choices[k]).lower()
cands = []
for k in choices:
if choices[k] in answer:
cands.append(k)
if len(cands) == 1:
return cands[0]
return False
def can_infer(answer, choices):
answer = str(answer)
copt = can_infer_option(answer, choices)
return copt if copt else can_infer_text(answer, choices)
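# Usage sketch: option letters are matched first, then option contents.
# >>> can_infer('The answer is B.', {'A': 'cat', 'B': 'dog'})
# 'B'
# >>> can_infer('It looks like a dog.', {'A': 'cat', 'B': 'dog'})
# 'B'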