from ...smp import *
from .multiple_choice import extract_answer_from_item
from PIL import Image, ImageOps
import torchvision
import random
import numbers
import math
import torch
def get_dimension_rating(data_path):
data = load(data_path)
result_board = {}
for idx, item in data.iterrows():
if item['task_type'] not in result_board:
result_board[item['task_type']] = [0, 0]
result_board[item['task_type']][1] += 1
if item['score']:
result_board[item['task_type']][0] += 1
correct = 0
total = 0
for key, value in result_board.items():
correct += value[0]
total += value[1]
result_board[key].append(f'{value[0] / value[1] * 100 :.2f}%')
result_board['overall'] = [correct, total, f'{correct / total * 100 :.2f}%']
return result_board
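# Illustrative usage of get_dimension_rating (hypothetical file name and task types): for a
# score file whose rows carry 'task_type' and a 0/1 'score', the returned board maps each
# task type to [num_correct, num_total, accuracy_string], e.g.
#   get_dimension_rating('MVBench_MODEL_score.xlsx')
#   # -> {'action_sequence': [45, 50, '90.00%'], 'object_existence': [30, 50, '60.00%'],
#   #     'overall': [75, 100, '75.00%']}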
def check_ans(pred, gt):
flag = False
pred_list = pred.lower().strip().split(' ')
pred_option, _ = pred_list[0], ' '.join(pred_list[1:])
gt_list = gt.lower().strip().split(' ')
gt_option, gt_content = gt_list[0], ' '.join(gt_list[1:])
if gt_content[-1] == '.':
gt_content = gt_content[:-1]
if pred_option.replace('.', '') in gt_option:
flag = True
elif gt_option in pred_option:
flag = True
return flag
def check_ans_with_model(pred, gt, model, item, dataset_name='MVBench'):
flag = False
pred_list = pred.lower().strip().split(' ')
pred_option, _ = pred_list[0], ' '.join(pred_list[1:])
gt_list = gt.lower().strip().split(' ')
gt_option, gt_content = gt_list[0], ' '.join(gt_list[1:])
if gt_content[-1] == '.':
gt_content = gt_content[:-1]
if pred_option.replace('.', '') in gt_option:
flag = True
elif gt_option in pred_option:
flag = True
elif extract_answer_from_item(model, item, dataset_name)['opt'] == item['answer']:
flag = True
return flag
def check_ans_advanced(pred, gt):
number_table = {
0: 'zero',
1: 'one',
2: 'two',
3: 'three',
4: 'four',
5: 'five',
6: 'six',
7: 'seven',
8: 'eight',
9: 'nine',
}
flag = False
pred_list = pred.lower().strip().split(' ')
pred_option, _ = pred_list[0], ' '.join(pred_list[1:])
gt_list = gt.lower().strip().split(' ')
gt_option, gt_content = gt_list[0], ' '.join(gt_list[1:])
if gt_content[-1] == '.':
gt_content = gt_content[:-1]
try:
gt_content = number_table[int(gt_content.strip('. \n'))]
print(gt_content)
except:
pass
if pred_option.replace('.', '') in gt_option:
flag = True
elif gt_option in pred_option:
flag = True
elif gt_content.lower().strip('. \n') in pred.lower().strip('. \n'):
flag = True
return flag
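# Illustrative example for check_ans_advanced (hypothetical values): with pred='Three dogs.'
# and gt='(C) 3.', the ground-truth content '3' is mapped to 'three' via number_table and is
# then found inside the lower-cased prediction, so the function returns True.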
class GroupRandomCrop(object):
def __init__(self, size):
if isinstance(size, numbers.Number):
self.size = (int(size), int(size))
else:
self.size = size
def __call__(self, img_group):
w, h = img_group[0].size
th, tw = self.size
out_images = list()
x1 = random.randint(0, w - tw)
y1 = random.randint(0, h - th)
for img in img_group:
assert (img.size[0] == w and img.size[1] == h)
if w == tw and h == th:
out_images.append(img)
else:
out_images.append(img.crop((x1, y1, x1 + tw, y1 + th)))
return out_images
class MultiGroupRandomCrop(object):
def __init__(self, size, groups=1):
if isinstance(size, numbers.Number):
self.size = (int(size), int(size))
else:
self.size = size
self.groups = groups
def __call__(self, img_group):
w, h = img_group[0].size
th, tw = self.size
out_images = list()
for i in range(self.groups):
x1 = random.randint(0, w - tw)
y1 = random.randint(0, h - th)
for img in img_group:
assert (img.size[0] == w and img.size[1] == h)
if w == tw and h == th:
out_images.append(img)
else:
out_images.append(img.crop((x1, y1, x1 + tw, y1 + th)))
return out_images
class GroupCenterCrop(object):
def __init__(self, size):
self.worker = torchvision.transforms.CenterCrop(size)
def __call__(self, img_group):
return [self.worker(img) for img in img_group]
class GroupRandomHorizontalFlip(object):
"""Randomly horizontally flips the given PIL.Image with a probability of 0.5
"""
def __init__(self, is_flow=False):
self.is_flow = is_flow
def __call__(self, img_group, is_flow=False):
v = random.random()
if v < 0.5:
ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group]
if self.is_flow:
for i in range(0, len(ret), 2):
# invert flow pixel values when flipping
ret[i] = ImageOps.invert(ret[i])
return ret
else:
return img_group
class GroupNormalize(object):
def __init__(self, mean, std):
self.mean = mean
self.std = std
def __call__(self, tensor):
rep_mean = self.mean * (tensor.size()[0] // len(self.mean))
rep_std = self.std * (tensor.size()[0] // len(self.std))
# TODO: make efficient
for t, m, s in zip(tensor, rep_mean, rep_std):
t.sub_(m).div_(s)
return tensor
class GroupScale(object):
""" Rescales the input PIL.Image to the given 'size'.
'size' will be the size of the smaller edge.
For example, if height > width, then image will be
rescaled to (size * height / width, size)
size: size of the smaller edge
interpolation: Default: PIL.Image.BILINEAR
"""
def __init__(self, size, interpolation=Image.BILINEAR):
self.worker = torchvision.transforms.Resize(size, interpolation)
def __call__(self, img_group):
return [self.worker(img) for img in img_group]
class GroupOverSample(object):
def __init__(self, crop_size, scale_size=None, flip=True):
self.crop_size = crop_size if not isinstance(
crop_size, int) else (crop_size, crop_size)
if scale_size is not None:
self.scale_worker = GroupScale(scale_size)
else:
self.scale_worker = None
self.flip = flip
def __call__(self, img_group):
if self.scale_worker is not None:
img_group = self.scale_worker(img_group)
image_w, image_h = img_group[0].size
crop_w, crop_h = self.crop_size
offsets = GroupMultiScaleCrop.fill_fix_offset(
False, image_w, image_h, crop_w, crop_h)
oversample_group = list()
for o_w, o_h in offsets:
normal_group = list()
flip_group = list()
for i, img in enumerate(img_group):
crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h))
normal_group.append(crop)
flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT)
if img.mode == 'L' and i % 2 == 0:
flip_group.append(ImageOps.invert(flip_crop))
else:
flip_group.append(flip_crop)
oversample_group.extend(normal_group)
if self.flip:
oversample_group.extend(flip_group)
return oversample_group
class GroupFullResSample(object):
def __init__(self, crop_size, scale_size=None, flip=True):
self.crop_size = crop_size if not isinstance(
crop_size, int) else (crop_size, crop_size)
if scale_size is not None:
self.scale_worker = GroupScale(scale_size)
else:
self.scale_worker = None
self.flip = flip
def __call__(self, img_group):
if self.scale_worker is not None:
img_group = self.scale_worker(img_group)
image_w, image_h = img_group[0].size
crop_w, crop_h = self.crop_size
w_step = (image_w - crop_w) // 4
h_step = (image_h - crop_h) // 4
offsets = list()
offsets.append((0 * w_step, 2 * h_step)) # left
offsets.append((4 * w_step, 2 * h_step)) # right
offsets.append((2 * w_step, 2 * h_step)) # center
oversample_group = list()
for o_w, o_h in offsets:
normal_group = list()
flip_group = list()
for i, img in enumerate(img_group):
crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h))
normal_group.append(crop)
if self.flip:
flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT)
if img.mode == 'L' and i % 2 == 0:
flip_group.append(ImageOps.invert(flip_crop))
else:
flip_group.append(flip_crop)
oversample_group.extend(normal_group)
oversample_group.extend(flip_group)
return oversample_group
class GroupMultiScaleCrop(object):
def __init__(self, input_size, scales=None, max_distort=1,
fix_crop=True, more_fix_crop=True):
self.scales = scales if scales is not None else [1, .875, .75, .66]
self.max_distort = max_distort
self.fix_crop = fix_crop
self.more_fix_crop = more_fix_crop
self.input_size = input_size if not isinstance(input_size, int) else [
input_size, input_size]
self.interpolation = Image.BILINEAR
def __call__(self, img_group):
im_size = img_group[0].size
crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size)
crop_img_group = [
img.crop(
(offset_w,
offset_h,
offset_w + crop_w,
offset_h + crop_h)) for img in img_group]
ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation)
for img in crop_img_group]
return ret_img_group
def _sample_crop_size(self, im_size):
image_w, image_h = im_size[0], im_size[1]
# find a crop size
base_size = min(image_w, image_h)
crop_sizes = [int(base_size * x) for x in self.scales]
crop_h = [
self.input_size[1] if abs(
x - self.input_size[1]) < 3 else x for x in crop_sizes]
crop_w = [
self.input_size[0] if abs(
x - self.input_size[0]) < 3 else x for x in crop_sizes]
pairs = []
for i, h in enumerate(crop_h):
for j, w in enumerate(crop_w):
if abs(i - j) <= self.max_distort:
pairs.append((w, h))
crop_pair = random.choice(pairs)
if not self.fix_crop:
w_offset = random.randint(0, image_w - crop_pair[0])
h_offset = random.randint(0, image_h - crop_pair[1])
else:
w_offset, h_offset = self._sample_fix_offset(
image_w, image_h, crop_pair[0], crop_pair[1])
return crop_pair[0], crop_pair[1], w_offset, h_offset
def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h):
offsets = self.fill_fix_offset(
self.more_fix_crop, image_w, image_h, crop_w, crop_h)
return random.choice(offsets)
@staticmethod
def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h):
w_step = (image_w - crop_w) // 4
h_step = (image_h - crop_h) // 4
ret = list()
ret.append((0, 0)) # upper left
ret.append((4 * w_step, 0)) # upper right
ret.append((0, 4 * h_step)) # lower left
ret.append((4 * w_step, 4 * h_step)) # lower right
ret.append((2 * w_step, 2 * h_step)) # center
if more_fix_crop:
ret.append((0, 2 * h_step)) # center left
ret.append((4 * w_step, 2 * h_step)) # center right
ret.append((2 * w_step, 4 * h_step)) # lower center
ret.append((2 * w_step, 0 * h_step)) # upper center
ret.append((1 * w_step, 1 * h_step)) # upper left quarter
ret.append((3 * w_step, 1 * h_step)) # upper right quarter
ret.append((1 * w_step, 3 * h_step)) # lower left quarter
            ret.append((3 * w_step, 3 * h_step)) # lower right quarter
return ret
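# Note on GroupMultiScaleCrop.fill_fix_offset (illustrative numbers): for a 256x256 image and a
# 224x224 crop, w_step = h_step = (256 - 224) // 4 = 8, so the five base offsets are
# (0, 0), (32, 0), (0, 32), (32, 32) and (16, 16), i.e. the four corners plus the center;
# more_fix_crop adds eight additional edge and quarter positions.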
class GroupRandomSizedCrop(object):
"""Random crop the given PIL.Image to a random size of (0.08 to 1.0) of the original size
and and a random aspect ratio of 3/4 to 4/3 of the original aspect ratio
This is popularly used to train the Inception networks
size: size of the smaller edge
interpolation: Default: PIL.Image.BILINEAR
"""
def __init__(self, size, interpolation=Image.BILINEAR):
self.size = size
self.interpolation = interpolation
def __call__(self, img_group):
for attempt in range(10):
area = img_group[0].size[0] * img_group[0].size[1]
target_area = random.uniform(0.08, 1.0) * area
aspect_ratio = random.uniform(3. / 4, 4. / 3)
w = int(round(math.sqrt(target_area * aspect_ratio)))
h = int(round(math.sqrt(target_area / aspect_ratio)))
if random.random() < 0.5:
w, h = h, w
if w <= img_group[0].size[0] and h <= img_group[0].size[1]:
x1 = random.randint(0, img_group[0].size[0] - w)
y1 = random.randint(0, img_group[0].size[1] - h)
found = True
break
else:
found = False
x1 = 0
y1 = 0
if found:
out_group = list()
for img in img_group:
img = img.crop((x1, y1, x1 + w, y1 + h))
assert (img.size == (w, h))
out_group.append(
img.resize(
(self.size, self.size), self.interpolation))
return out_group
else:
# Fallback
scale = GroupScale(self.size, interpolation=self.interpolation)
crop = GroupRandomCrop(self.size)
return crop(scale(img_group))
class ConvertDataFormat(object):
def __init__(self, model_type):
self.model_type = model_type
def __call__(self, images):
if self.model_type == '2D':
return images
tc, h, w = images.size()
t = tc // 3
images = images.view(t, 3, h, w)
images = images.permute(1, 0, 2, 3)
return images
class Stack(object):
def __init__(self, roll=False):
self.roll = roll
def __call__(self, img_group):
if img_group[0].mode == 'L':
return np.concatenate([np.expand_dims(x, 2)
for x in img_group], axis=2)
elif img_group[0].mode == 'RGB':
if self.roll:
return np.concatenate([np.array(x)[:, :, ::-1]
for x in img_group], axis=2)
else:
# print(np.concatenate(img_group, axis=2).shape)
# print(img_group[0].shape)
return np.concatenate(img_group, axis=2)
class ToTorchFormatTensor(object):
""" Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255]
to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """
def __init__(self, div=True):
self.div = div
def __call__(self, pic):
if isinstance(pic, np.ndarray):
# handle numpy array
img = torch.from_numpy(pic).permute(2, 0, 1).contiguous()
else:
# handle PIL Image
img = torch.ByteTensor(
torch.ByteStorage.from_buffer(
pic.tobytes()))
img = img.view(pic.size[1], pic.size[0], len(pic.mode))
# put it from HWC to CHW format
# yikes, this transpose takes 80% of the loading time/CPU
img = img.transpose(0, 1).transpose(0, 2).contiguous()
return img.float().div(255) if self.div else img.float()
class IdentityTransform(object):
def __call__(self, data):
return data
import re
def extract_answer(output_string, task_type="yes_no"):
"""
Extracts the answer from the output string based on the task type.
Parameters:
output_string (str): The output string.
task_type (str): The type of task. Must be either "yes_no" or "multiple_choice".
Returns:
int:
1 if "yes" or "A"
0 if "no" or "B"
-1 if no relevant answer is found.
Raises a ValueError if an unsupported task_type is provided.
"""
def find_word_position(string, word):
pattern = r'\b' + re.escape(word) + r'\b'
match = re.search(pattern, string, re.IGNORECASE)
if match:
return match.start()
return -1
if task_type not in ["yes_no", "multiple_choice"]:
raise ValueError(f"Task type {task_type} not supported. Must be 'yes_no' or 'multiple_choice'.")
if task_type == "yes_no":
position_yes_and_a = find_word_position(output_string, "yes")
position_no_and_b = find_word_position(output_string, "no")
elif task_type == "multiple_choice":
position_yes_and_a = find_word_position(output_string, "A")
position_no_and_b = find_word_position(output_string, "B")
if position_yes_and_a == -1 and position_no_and_b == -1:
print(f"No answer found in the output string: {output_string}.")
return -1
elif position_yes_and_a != -1 and position_no_and_b != -1:
return 1 if position_yes_and_a < position_no_and_b else 0
else:
return 0 if position_yes_and_a == -1 else 1
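# Illustrative examples for extract_answer (hypothetical model outputs):
#   extract_answer('Yes, the person is running.', task_type='yes_no')       # -> 1
#   extract_answer('The answer is B.', task_type='multiple_choice')         # -> 0
#   extract_answer('I cannot tell.', task_type='yes_no')                    # -> -1 (no answer found)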
def get_scores(scores):
"""
Calculate various scores based on the given results.
Args:
scores (dict or list): A dictionary or list containing results where each result can be:
- dict: {id: {"q0_i0": 1 or 0, "q0_i1": 1 or 0, "q1_i0": 1 or 0, "q1_i1": 1 or 0}, ...}
- list: [[q0_i0 (1 or 0), q0_i1 (1 or 0), q1_i0 (1 or 0), q1_i1 (1 or 0)], ...]
The keys "q0_i0", "q0_i1", "q1_i0", "q1_i1" represent combinations of questions and images:
- "q0_i0" means question_0 on image_0
- "q0_i1" means question_0 on image_1
- "q1_i0" means question_1 on image_0
- "q1_i1" means question_1 on image_1
Returns:
dict: A dictionary containing the calculated scores:
- 'Q_Acc': Average question score
- 'I_Acc': Average image score
- 'Acc': Average binary VQA score
- 'G_Acc': Average group score
"""
Q_Acc = 0.0
I_Acc = 0.0
Acc = 0.0
G_Acc = 0.0
num_samples = len(scores)
def calculate_image_score(result):
image_correct = 0
if isinstance(result, dict):
if result["q0_i0"] == 1.0 and result["q1_i0"] == 0.0:
image_correct += 1
if result["q1_i1"] == 1.0 and result["q0_i1"] == 0.0:
image_correct += 1
elif isinstance(result, list):
if result[0] == 1.0 and result[2] == 0.0:
image_correct += 1
if result[3] == 1.0 and result[1] == 0.0:
image_correct += 1
return image_correct
def calculate_question_score(result):
text_correct = 0
if isinstance(result, dict):
if result["q0_i0"] == 1.0 and result["q0_i1"] == 0.0:
text_correct += 1
if result["q1_i1"] == 1.0 and result["q1_i0"] == 0.0:
text_correct += 1
else:
if result[0] == 1.0 and result[1] == 0.0:
text_correct += 1
if result[3] == 1.0 and result[2] == 0.0:
text_correct += 1
return text_correct
def calculate_binary_score(result):
binary_score_correct = 0
if isinstance(result, dict):
binary_score_correct += 1 if result["q0_i0"] == 1.0 else 0
binary_score_correct += 1 if result["q0_i1"] == 0.0 else 0
binary_score_correct += 1 if result["q1_i0"] == 0.0 else 0
binary_score_correct += 1 if result["q1_i1"] == 1.0 else 0
else:
binary_score_correct += 1 if result[0] == 1.0 else 0
binary_score_correct += 1 if result[1] == 0.0 else 0
binary_score_correct += 1 if result[2] == 0.0 else 0
binary_score_correct += 1 if result[3] == 1.0 else 0
return binary_score_correct
def calculate_group(result):
group_correct = 0
if calculate_question_score(result) == 2 and calculate_image_score(result) == 2:
group_correct += 1
return group_correct
if isinstance(scores, dict):
for _, result in scores.items():
Q_Acc += calculate_question_score(result)
I_Acc += calculate_image_score(result)
Acc += calculate_binary_score(result)
G_Acc += calculate_group(result)
else:
for result in scores:
Q_Acc += calculate_question_score(result)
I_Acc += calculate_image_score(result)
Acc += calculate_binary_score(result)
G_Acc += calculate_group(result)
results = {
'Q_Acc': Q_Acc / float(num_samples * 2),
'I_Acc': I_Acc / float(num_samples * 2),
'Acc': Acc / float(num_samples * 4),
'G_Acc': G_Acc / num_samples
}
return results
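# Illustrative example for get_scores (hypothetical results): one group answered perfectly and
# one group where only q0_i0 is correct.
#   get_scores([[1, 0, 0, 1], [1, 0, 0, 0]])
#   # -> {'Q_Acc': 0.75, 'I_Acc': 0.75, 'Acc': 0.875, 'G_Acc': 0.5}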
from ...smp import *
def OCRBench_eval(eval_file):
OCRBench_score = {
'Regular Text Recognition': 0,
'Irregular Text Recognition': 0,
'Artistic Text Recognition': 0,
'Handwriting Recognition': 0,
'Digit String Recognition': 0,
'Non-Semantic Text Recognition': 0,
'Scene Text-centric VQA': 0,
'Doc-oriented VQA': 0,
'Key Information Extraction': 0,
'Handwritten Mathematical Expression Recognition': 0
}
logger = get_logger('Evaluation')
data = load(eval_file)
lt = len(data)
lines = [data.iloc[i] for i in range(lt)]
for i in tqdm(range(len(lines))):
line = lines[i]
predict = str(line['prediction'])
answers = eval(line['answer'])
category = line['category']
if category == 'Handwritten Mathematical Expression Recognition':
for j in range(len(answers)):
answer = answers[j].strip().replace('\n', ' ').replace(' ', '')
predict = predict.strip().replace('\n', ' ').replace(' ', '')
if answer in predict:
OCRBench_score[category] += 1
break
else:
for j in range(len(answers)):
answer = answers[j].lower().strip().replace('\n', ' ')
predict = predict.lower().strip().replace('\n', ' ')
if answer in predict:
OCRBench_score[category] += 1
break
final_score_dict = {}
final_score_dict['Text Recognition'] = (
OCRBench_score['Regular Text Recognition'] + OCRBench_score['Irregular Text Recognition']
+ OCRBench_score['Artistic Text Recognition'] + OCRBench_score['Handwriting Recognition']
+ OCRBench_score['Digit String Recognition'] + OCRBench_score['Non-Semantic Text Recognition']
)
final_score_dict['Scene Text-centric VQA'] = OCRBench_score['Scene Text-centric VQA']
final_score_dict['Doc-oriented VQA'] = OCRBench_score['Doc-oriented VQA']
final_score_dict['Key Information Extraction'] = OCRBench_score['Key Information Extraction']
final_score_dict['Handwritten Mathematical Expression Recognition'] = \
OCRBench_score['Handwritten Mathematical Expression Recognition']
final_score_dict['Final Score'] = (
final_score_dict['Text Recognition'] + final_score_dict['Scene Text-centric VQA']
+ final_score_dict['Doc-oriented VQA'] + final_score_dict['Key Information Extraction']
+ final_score_dict['Handwritten Mathematical Expression Recognition']
)
final_score_dict['Final Score Norm'] = float(final_score_dict['Final Score']) / 10
score_pth = eval_file.replace('.xlsx', '_score.json')
dump(final_score_dict, score_pth)
logger.info(f'OCRBench_eval successfully finished evaluating {eval_file}, results saved in {score_pth}')
logger.info('Score: ')
for key, value in final_score_dict.items():
logger.info('{}:{}'.format(key, value))
import re
import json
from math import isclose
import sympy as sp
from sympy import simplify, Eq, sympify, Pow
from sympy.parsing.latex import parse_latex
import antlr4
from decimal import Decimal, getcontext
from fractions import Fraction
import sys
import math
chinese_answer_type_dict = {
'Numerical': '数值',
'Expression': '表达式',
'Equation': '方程',
'Interval': '区间'
}
english_answer_type_dict = {
'Numerical': 'a numerical value',
'Expression': 'an expression',
'Equation': 'an equation',
'Interval': 'an interval'
}
def get_single_answer_type_text(answer_type, is_chinese):
if '-' in answer_type: # No need now
answer_type = answer_type[:answer_type.find('-')]
for t in ['Numerical', 'Expression', 'Equation', 'Interval']:
if t in answer_type:
if is_chinese:
return chinese_answer_type_dict[t]
else:
return english_answer_type_dict[t]
    raise ValueError(f'Error parsing answer type {answer_type}!')
def get_answer_type_text(answer_type, is_chinese, multiple_answer):
# 'Tuple' has various meanings in different context, such as position or values of a series of variable,
# so it may lead to confusion to directly use 'tuple' in the prompt.
if ('Need_human_evaluate' in answer_type) or ('Tuple' in answer_type):
full_answer_text = ''
else:
if not multiple_answer:
answer_text = get_single_answer_type_text(answer_type, is_chinese)
if is_chinese:
full_answer_text = f',答案类型为{answer_text}'
else:
full_answer_text = f"The answer of The problem should be {answer_text}. "
else:
if ',' not in answer_type: # Same answer type for all answers
answer_text = get_single_answer_type_text(answer_type, is_chinese)
if is_chinese:
full_answer_text = f',题目有多个答案,答案类型均为{answer_text}'
else:
full_answer_text = f'The problem has multiple answers, each of them should be {answer_text}. '
else:
answer_types = answer_type.split(',')
answer_types = [get_single_answer_type_text(t, is_chinese) for t in answer_types]
if len(set(answer_types)) == 1:
answer_text = answer_types[0]
if is_chinese:
full_answer_text = f',题目有多个答案,答案类型均为{answer_text}'
else:
full_answer_text = f'The problem has multiple answers, each of them should be {answer_text}. '
else:
if is_chinese:
answer_text = '、'.join(answer_types)
full_answer_text = f',题目有多个答案,答案类型分别为{answer_text}'
else:
answer_text = ', '.join(answer_types)
full_answer_text = (
f'The problem has multiple answers, with the answers in order being {answer_text}. '
)
return full_answer_text
def make_input(prompt, question_content):
    # could be diversified depending on the vllm backend; not implemented for now
input = prompt + '\n' + question_content
return input
sys.set_int_max_str_digits(1000000)
# Set the precision for Decimal arithmetic
getcontext().prec = 50
class MathJudger:
def __init__(self):
self.special_signal_map = {
"\\left": "",
"\\right": "",
"∶": ":",
",": ",",
"$": "",
"\\approx": "=",
"\\simeq": "=",
"\\sim": "=",
"^\\prime": "'",
"^{\\prime}": "'",
"^\\circ": "",
"%": "",
}
self.pi = parse_latex("\\pi")
self.precision = 1e-8
def split_by_comma(self, expr: str):
in_bracket_num = 0
splitted_expr = []
start_idx = 0
for i, char in enumerate(expr):
if char == "(" or char == "[":
in_bracket_num += 1
elif char == ")" or char == "]":
in_bracket_num -= 1
elif char == "," and in_bracket_num == 0:
splitted_expr.append(expr[start_idx:i].strip())
start_idx = i + 1
if start_idx < len(expr):
splitted_expr.append(expr[start_idx:].strip())
return splitted_expr
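    # Illustrative example: commas inside brackets are preserved, so
    #   split_by_comma('(1, 2), x+1, [3, 4]')  ->  ['(1, 2)', 'x+1', '[3, 4]']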
def trans_plus_minus_sign(self, expr_list: list):
new_expr_list = []
for expr in expr_list:
if "\\pm" in expr:
new_expr_list.append(expr.replace("\\pm", "+"))
new_expr_list.append(expr.replace("\\pm", "-"))
else:
new_expr_list.append(expr)
return new_expr_list
def judge(self, expression1, expression2, precision=1e-8):
        # (expression1 is assumed to be the ground truth)
precision = precision if isinstance(precision, list) else [precision]
try:
expression1, expression2 = self.preprocess(expression1, expression2)
except:
return False
if expression1 == expression2:
# print("原生相等")
return True
        # Strip Chinese characters from both strings; answers containing Chinese (e.g. "能"/"不能") have already been handled by the direct equality check above
expression1 = re.sub(r'[\u4e00-\u9fff]+', '', expression1)
expression2 = re.sub(r'[\u4e00-\u9fff]+', '', expression2)
expression1 = self.split_by_comma(expression1)
expression2 = self.split_by_comma(expression2)
temp_list1 = self.trans_plus_minus_sign(expression1)
temp_list2 = self.trans_plus_minus_sign(expression2)
        # Build the per-answer precision (tolerance) list
if len(precision) <= 1:
precision = precision * len(temp_list1)
if len(temp_list1) != len(temp_list2):
return False
        # Try to pair up the elements of the two lists so that every pair is equal; this supports comparing multiple answers
idx = -1
while len(temp_list1) != 0:
idx = (idx + 1) % len(temp_list1)
item1 = temp_list1[idx]
self.precision = precision[idx]
# print(self.precision)
for item2 in temp_list2:
if self.is_equal(item1, item2):
temp_list1.remove(item1)
temp_list2.remove(item2)
precision.remove(self.precision)
break
else:
# If we didn't break from the inner loop, it means no match was found
return False
# If all elements are matched and removed, the lists can be paired
return True
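    # Illustrative usage (hypothetical answers; the ground truth is passed first):
    #   judger = MathJudger()
    #   judger.judge('\\boxed{\\frac{1}{2}}', '0.5')   # expected True (same value)
    #   judger.judge('\\boxed{2, 3}', '3, 2')          # expected True (order-insensitive pairing)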
    def is_interval(self, expr):
        return expr.startswith(("(", "[")) and expr.endswith((")", "]"))
    # Before numerical evaluation, substitute the sympy pi symbol with its numerical value
    def sympy_sub_pi(self, expression_sympy):
        return expression_sympy.subs(self.pi, math.pi)
    # expression1 is assumed to be the ground truth
def is_equal(self, expression1, expression2):
if expression1 == expression2 and expression1 != "" and expression2 != "":
# print("原生等价")
return True
        # First check whether both are intervals; if so compare them, and return False if they differ
if self.is_interval(expression1) and self.is_interval(expression2):
try:
if self.interval_equal(expression1, expression2):
# print("区间等价")
return True
except:
return False
        # Then check numerical equality
try:
if self.numerical_equal(expression1, expression2):
# print("数值等价")
return True
except:
pass
        # Then check expression equivalence
try:
if self.expression_equal(expression1, expression2) and not ("=" in expression1 and "=" in expression2):
# print("表达式等价")
return True
except:
pass
        # Then check equation equivalence
try:
if self.equation_equal(expression1, expression2):
# print("等式等价")
return True
except:
pass
return False
    # Check whether two numerical values are equal within the allowed tolerance
def numerical_equal(self, expression1: str, expression2: str, include_percentage: bool = True):
"""
(默认 expression1 为 Ground_Truth)
函数: 判读两个数值是否在误差允许范围内相等
步骤1: 将可能出现的百分号的情况包含进来
步骤2: 使用 math.isclose 函数判断是否相等
"""
reference = float(expression1)
prediction = float(expression2)
if include_percentage:
gt_result = [reference / 100, reference, reference * 100]
else:
gt_result = [reference]
for item in gt_result:
# if isclose(item, prediction, abs_tol=self.precision, rel_tol=0):
if abs(item - prediction) <= self.precision * 1.01:
return True
return False
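    # Illustrative example: with the default precision, numerical_equal('0.5', '50') returns True
    # because the percentage variants 0.005, 0.5 and 50 of the ground truth are all checked.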
def expression_equal(self, exp1, exp2):
"""
(默认 expression1 为 Ground_Truth)
函数: 判断两个表达式是否在数学意义上等价
步骤1: 提取表达式, 防止有的模型会给出"x=1"而不是"1"
步骤2: 使用 sympy 库进行等价判断
"""
# 只提取等号右边的表达式,一般左边是所求的量
def extract_expression(expression):
if "=" in expression:
expression = expression.split("=")[1]
return expression.strip()
exp1 = extract_expression(exp1)
exp2 = extract_expression(exp2)
exp_too_long = len(exp1) > 300 or len(exp2) > 300
        # Convert the expressions into a form sympy can handle
expr1_sym = sympify(parse_latex(exp1))
expr2_sym = sympify(parse_latex(exp2))
if expr1_sym == expr2_sym:
return True
else:
expr1_sym = self.sympy_sub_pi(expr1_sym)
expr2_sym = self.sympy_sub_pi(expr2_sym)
            # If both expressions evaluate to concrete numbers, compare them numerically
if (expr1_sym.has(sp.Symbol) and not expr2_sym.has(sp.Symbol)) or (
not expr1_sym.has(sp.Symbol) and expr2_sym.has(sp.Symbol)):
return False
elif not expr1_sym.has(sp.Symbol) and not expr2_sym.has(sp.Symbol):
try:
if not (self.can_compute_power(expr1_sym) and self.can_compute_power(expr2_sym)):
print(
"These two number can not be calculated by current computer for: "
f"\"{str(expr1_sym)}\" and \"{str(expr2_sym)}\""
)
return False
if exp_too_long:
print(f'Expression {exp1} or {exp2} is too long to compute. ')
return False
if abs(expr1_sym.evalf() - expr2_sym.evalf()) <= self.precision * 1.01:
return True
else:
return False
except:
return False
elif exp_too_long:
print(f'Expression {exp1} or {exp2} is too long to compute. ')
return False
else:
try:
simplified_expr = simplify(expr1_sym - expr2_sym)
num_value = simplified_expr.evalf()
return abs(num_value) < 1e-3
except:
return False
def equation_equal(self, expression1, expression2):
"""
(默认 expression1 为 Ground_Truth)
函数: 判断两个方程是否在数学意义上等价
步骤1: 将一个方程/等式化简为标准方程, 即等式的右边严格等于0, 接下来只需要判断两个等式的左边是否"等价"
步骤2: 使用 sympy 库计算两个等式左边的商, 如果这个商或者这个商的倒数为整数, 那么数学意义上我们可以推导出这两个方程等价👌
"""
# 将等式的右边都移到左边,并返回一个 sympy 格式的表达式
def simplify_equation(latex_eq):
            # Split the equation into its left and right sides
lhs, rhs = latex_eq.split('=')
            # Parse the LaTeX expressions with parse_latex
lhs_expr = parse_latex(lhs)
rhs_expr = parse_latex(rhs)
            # Build the equation object
equation = Eq(lhs_expr, rhs_expr)
            # Simplify the equation by moving the right-hand side to the left
simplified_eq = simplify(equation.lhs - equation.rhs)
return simplified_eq
expr1_sym = simplify_equation(expression1)
expr2_sym = simplify_equation(expression2)
division_result_1 = simplify(expr1_sym / expr2_sym)
division_result_2 = simplify(expr2_sym / expr1_sym)
        # If the quotient of the two simplified forms is a non-zero integer, the equations are equivalent
if (division_result_1.is_Integer and division_result_1 != 0) or (
division_result_2.is_Integer and division_result_2 != 0):
return True
else:
return False
def interval_equal(self, expression1, expression2):
        # Purpose: decide whether two intervals are mathematically equivalent.
        # Step 1: simplify the interval expressions, dropping decorative tokens such as "\left", "\right" and any "x \in" prefix.
        # Step 2: compare the bracket types at both ends and the expressions inside the intervals.
def compare_two_interval(inter1, inter2):
            # First compare the brackets at both ends; continue only if they match
if inter1[0] != inter2[0] or inter1[-1] != inter2[-1]:
return False
inter1 = inter1.strip('[]()')
inter2 = inter2.strip('[]()')
            # Split the interval into its components
items_1 = inter1.split(',')
items_2 = inter2.split(',')
for item_1, item_2 in zip(items_1, items_2):
if not self.expression_equal(item_1, item_2):
return False
return True
interval1 = expression1
interval2 = expression2
if interval1 == interval2:
return True
else:
inter_list1 = interval1.split("\\cup")
inter_list2 = interval2.split("\\cup")
if len(inter_list1) != len(inter_list2):
return False
else:
for inter1, inter2 in zip(inter_list1, inter_list2):
if not compare_two_interval(inter1, inter2):
return False
return True
def preprocess(self, expression1, expression2):
        # Try to capture the content of \boxed{...}; multiple matches are joined with commas. If no \boxed{} is found, fall back to the formulas on the last line
def extract_boxed_content(latex_str):
            # Find all \boxed{...} occurrences
boxed_matches = re.finditer(r'\\boxed{', latex_str)
results = ""
for match in boxed_matches:
start_index = match.end()
end_index = start_index
stack = 1
                # Search from just after \boxed{ until the matching closing brace is found
while stack > 0 and end_index < len(latex_str):
if latex_str[end_index] == '{':
stack += 1
elif latex_str[end_index] == '}':
stack -= 1
end_index += 1
if stack == 0:
                    # Extract the content inside \boxed{}
content = latex_str[start_index:end_index - 1]
results += content + ","
else:
                    # The braces are not properly balanced; raise an error
raise ValueError("Mismatched braces in LaTeX string.")
            # If no \boxed{} was found, fall back to extracting every $...$ formula from the last line of the text
if results == "":
last_line_ans = latex_str.strip().split("\n")[-1]
dollar_pattern = r"\$(.*?)\$"
answers = re.findall(dollar_pattern, last_line_ans)
if answers:
for ans in answers:
results += ans + ","
else:
results = latex_str
return results
        def special_symbol_replace(expression):
if "\\in " in expression:
expression = expression.split("\\in ")[1]
            # Replace special characters; they are purely decorative and do not affect LaTeX parsing
for signal in self.special_signal_map:
expression = expression.replace(signal, self.special_signal_map[signal])
expression = expression.strip("\n$,.:;^_=+`!@#$%^&*~,。")
pattern = r'\\(?:mathrm|mathbf)\{~?([^}]*)\}'
expression = re.sub(pattern, r'\1', expression)
return expression
exp1, exp2 = extract_boxed_content(expression1), extract_boxed_content(expression2)
        exp1, exp2 = special_symbol_replace(exp1), special_symbol_replace(exp2)
return exp1, exp2
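    # Illustrative example (hypothetical strings): preprocess extracts the \boxed{} content and
    # strips decorative LaTeX, e.g.
    #   preprocess('The area is $\\boxed{\\frac{\\pi}{4}}$.', '$\\frac{\\pi}{4}$')
    #   # -> ('\\frac{\\pi}{4}', '\\frac{\\pi}{4}')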
def can_compute_power(self, expr):
"""
Check if the power expression can be computed.
Parameters:
expr (sympy expression): The expression to check.
Returns:
bool: True if the expression can be computed, False otherwise.
"""
# Check if the expression is a power expression
if isinstance(expr, Pow):
# Extract the base and the exponent
base, exp = expr.as_base_exp()
# Check if the base and the exponent are numbers
if base.is_number and exp.is_number:
# Set a threshold for the maximum size of the exponent
MAX_EXP = 1000 # This threshold can be adjusted based on the computing environment
# Check if the exponent is greater than the threshold
if abs(exp.evalf()) > MAX_EXP:
return False
else:
return True
else:
# If the base or the exponent is not a number, we cannot compute the power
return False
else:
# If the expression is not a power expression, return True as it is not the case we are checking for
return True
def extract_answer(is_chinese, model_output, is_deepseek=False):
# deepseekmath has special answering format
if str(model_output) == 'nan':
model_output = 'nan'
if is_deepseek:
if is_chinese:
matches = re.findall('## 解题答案(.*)', model_output)
else:
matches = re.findall('The answer is: (.*)', model_output)
        # Check whether at least one match was found; if not, return the whole output so \boxed{} can be searched later
if matches:
            # If there are multiple matches, take the last one
model_answer = matches[-1].strip()
return model_answer
else:
return model_output
if is_chinese:
matches = re.findall('所以最终答案是(.*)', model_output)
else:
matches = re.findall('So the final answer is (.*)', model_output)
    # Check whether at least one match was found; if not, return the whole output so \boxed{} can be searched later
if matches:
        # If there are multiple matches, take the last one
model_answer = matches[-1].strip()
return model_answer
else:
return model_output
def calculate_merged_accuracy(reference_dir, text_only):
pass
from ...smp import *
from ...utils import can_infer
FAIL_MSG = 'Failed to obtain answer via API.'
def get_gpt4_ICE_for_qspatial():
example_1 = """
Hint: Please answer the question requiring in a tuple format. The tuple should contain a numeric value and a unit,
e.g., (1, m), (2.2, cm), (3.12, meter), at the end.\n
Model response: **Object Identification**
* The object in question is a chair.
* The chair is not visible in the image.
**Conclusion**
The height of the chair cannot be determined from the provided image.\n
Extracted answer: (0, cm)
"""
example_2 = """
Hint: Please answer the question requiring in a tuple format. The tuple should contain a numeric value and a unit,
e.g., (1, inch), (1.2, cm), (3.0, feet), at the end.\n
Model response: **Step 1: Identify the stapler and the recycle bin in the image.**
The stapler is located on the wooden table, and the recycle bin is located on the floor.
**Step 2: Determine the distance between the stapler and the recycle bin.**
The stapler is 0.5 meters from the edge of the table, and the recycle bin is 1.5 meters from the edge of the table.
Therefore, the minimum distance between the stapler and the recycle bin is 1.5 - 0.5 = 1 meter.
**Answer:** 1 m\n
Extracted answer: (1, m)
"""
example_3 = """
Hint: Please answer the question requiring in a tuple format. The tuple should contain a numeric value and a unit,
e.g., (1, foot), (2, cm), (4.3, meter), at the end.\n
Model response: The mirror in the image is approximately 5 feet 4 inches tall.\n
Extracted answer: (64, inch)
"""
example_4 = """
Hint: Please answer the question requiring in a tuple format. The tuple should contain a numeric value and a unit,
e.g., (0.1, cm), (2.9, cm), (0.3, meter), at the end.\n
Model response: The minimum distance between the wooden chair and the chair near the camera in the image is 1.7 feet.\n
Extracted answer: (1.7, feet)
"""
example_5 = """
Hint: Please answer the question requiring in a tuple format. The tuple should contain a numeric value and a unit,
e.g., (5.1, cm), (0.9, cm), (55, mm), at the end.\n
Model response: The height of the painting's bottom edge from the floor is approximately 4.5 feet.\n
Extracted answer: (4.5, feet)
"""
return [example_1, example_2, example_3, example_4, example_5]
def list_to_dict(lst):
return {chr(65 + i): val for i, val in enumerate(lst)}
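# Illustrative example: list_to_dict(['cat', 'dog', 'bird']) -> {'A': 'cat', 'B': 'dog', 'C': 'bird'}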
def post_check(line, prefetch=False):
res = None
ans = line['answer']
response = line['prediction'] if prefetch else line['res']
try:
if line['question_type'] == 'multi_choice':
ans = line['answer_option']
choices = list_to_dict(eval(line['choices']))
res = can_infer(response, choices)
if prefetch:
return res
else:
if line['answer_type'] == 'integer':
res = int(response)
ans = int(line['answer'])
elif line['answer_type'] == 'float':
res = float(response)
ans = float(line['answer'])
else:
res = str(res)
ans = str(ans)
except ValueError:
pass
if res == ans:
return res if prefetch else True
else:
return False
def build_qspatial_gpt4_prompt(line):
task_description = """
Please read the following example.
Then extract the answer from the model response and type it at the end of the prompt.\n
"""
prediction = str(line['prediction'])
prompt = task_description
examples = get_gpt4_ICE_for_qspatial()
for example in examples:
prompt += example + '\n'
    prompt += 'Model response: ' + prediction
prompt += '\nExtracted answer:'
return prompt
def QSpatial_auxeval(model, line):
prompt = build_qspatial_gpt4_prompt(line)
log = ''
retry = 5
for i in range(retry):
prediction = line['prediction']
res = model.generate(prompt, temperature=i * 0.5)
if FAIL_MSG in res:
log += f'Try {i}: output is {prediction}, failed to parse.\n'
else:
log += 'Succeed'
return dict(log=log, res=res)
log += 'All 5 retries failed.\n'
return dict(log=log, res='')
"""
Copied from https://github.com/allenai/allennlp-semparse
Modified from https://github.com/naver-ai/tablevqabench
"""
import re
import unicodedata
import time
from abc import ABCMeta, abstractmethod
from math import isinf, isnan
# Vision Prompts
VWTQ_PROMPT = (
'You are asked to answer questions asked on an image.\n'
'You should answer the question with a single word.\n'
'Example: \n'
'Question: what was the only year mr. wu competed in the olympic games?\n'
'Answer: 2004\n'
'Question: which township in pope county, arkansas has the least amount of water area?\n'
'Answer: Freeman\n'
'If you have multiple answers, please separate them with || marks. Example: Apple||Banana||Tomato\n\n'
'Question: {question}\n'
'Answer:'
)
VTABFACT_PROMPT = (
'You are asked to answer whether the statement is True or False based on given image\n'
'You should only answer True or False.\n'
'Example: \n'
'Statement: the milwaukee buck win 6 game in the 2010 - 11 season\n'
'Answer: True\n'
'Statement: only the top team score above the average of 8.8\n'
'Answer: False\n\n'
'Statement: {question}\n'
'Answer:'
)
FINTABNETQA_PROMPT = (
    'You are asked to answer questions asked on an image.\n'
'You should answer the question within a single word or few words.\n'
'If units can be known, the answer should include units such as $, %, million and etc.\n'
'Example: \n'
'Question: What were the total financing originations for the fiscal year ended October 31, 2004?\n'
'Answer: $3,852 million\n'
'Question: What is the time period represented in the table?\n'
'Answer: October 31\n'
'Question: What was the percentage of net sales for selling, general and administrative expenses in 2006?\n'
'Answer: 34.2%\n'
'Question: {question}\n'
'Answer:'
)
def evaluate_tabfact(data, score_keys):
num_examples = 0
num_correct = 0
manual_check = 0
start_time = time.time()
for instance in data:
if instance['prediction'] is None:
instance['prediction'] = 'none'
pred = instance['prediction'].lower()
gt = instance['answer']
num_examples += 1
if 'true' in pred and 'false' in pred:
manual_check += 1
score = None
elif 'true' in pred and gt == '1':
num_correct += 1
score = 1
elif 'false' in pred and gt == '0':
num_correct += 1
score = 1
else:
score = 0
instance['scores'] = {score_keys[0]: score}
if manual_check > 0:
print(f'the number of not properly parsed samples: {manual_check}')
end_time = time.time()
elapsed_time = end_time - start_time
Accuracy = round((num_correct + 1e-9) / (num_examples + 1e-9), 8) * 100
meta = {
'evaluators': 'correctness',
'score_info': [score_keys[0]],
'evaluated_time': elapsed_time,
'total_num_sample': len(data),
'average_scores': [Accuracy],
}
return meta
def evaluate_wtq(data, score_keys):
num_examples = 0
num_correct = 0
start_time = time.time()
for instance in data:
pred = instance['prediction'].replace('||', '|')
gt = instance['answer']
original_strings = tsv_unescape_list(gt)
target_values = to_value_list(original_strings)
predicted_strings = tsv_unescape_list(pred)
predicted_values = to_value_list(predicted_strings)
correct = check_denotation(target_values, predicted_values)
num_examples += 1
score = 0
if correct:
num_correct += 1
score = 1
instance['scores'] = {score_keys[0]: score}
end_time = time.time()
elapsed_time = end_time - start_time
Accuracy = round((num_correct + 1e-9) / (num_examples + 1e-9), 8) * 100
meta = {
'evaluators': 'correctness',
'score_info': [score_keys[0]],
'evaluated_time': elapsed_time,
'total_num_sample': len(data),
'average_scores': [Accuracy],
}
return meta
def evaluate_fintabnet(data, score_keys):
num_examples = 0
num_correct, _num_correct = 0, 0
start_time = time.time()
for instance in data:
pred, preds = fintabnet_normalize(instance['prediction'])
gt, gts = fintabnet_normalize(instance['answer'])
correct = 1 if gt == pred else 0
_correct = any(_pred == _gt for _pred in preds for _gt in gts)
num_examples += 1
score, _score = 0, 0
if correct:
num_correct += 1
score = 1
if _correct:
_num_correct += 1
_score = 1
instance['scores'] = {score_keys[0]: _score, 'exact_score': score}
end_time = time.time()
elapsed_time = end_time - start_time
Accuracy = round((num_correct + 1e-9) / (num_examples + 1e-9), 8) * 100
_Accuracy = round((_num_correct + 1e-9) / (num_examples + 1e-9), 8) * 100
meta = {
'evaluators': 'correctness',
'score_info': ['relieved_accuracy', score_keys[0]],
'evaluated_time': elapsed_time,
'total_num_sample': len(data),
'average_scores': [_Accuracy, Accuracy],
}
return meta
def fintabnet_normalize(s):
s = normalize(s)
remove_words = [
'dollar', 'gallons', 'square feet', 'shares', 'mbtu',
'mbpd', 'mbbls', 'mmbtu', 'unit', 'gwh', 'year', 'mmcf', 'mile', 'mboe'
]
# Data specific filtering using regular expressions
# Remove special characters like $, (, and )
s = re.sub(r'[\$\(\),]', '', s)
# Replace "dollar" with empty string if it's not part of another word
pattern = r'\b(' + '|'.join(remove_words) + r')s?\b'
s = re.sub(pattern, '', s, flags=re.IGNORECASE)
# Unit conversion dictionary with regex patterns for flexibility
unit_conversion = {
r' \bthousand\b': 'e3',
r' \bmillion\b': 'e6',
r' \bbillion\b': 'e9',
r'\bthousand\b': 'e3',
r'\bmillion\b': 'e6',
r'\bbillion\b': 'e9',
r' ?%': 'e-2',
}
# Convert percentages to their decimal representation.
# Applying this after unit_conversion prevents "percent" from being processed
# in cases like "million %", which would be incorrect.
# s = re.sub(r' ?%', 'e-2', s)
# s_percent = re.sub(r' ?%', '', s_percent)
s_unit_free = s
# Iterate over unit_conversion and apply transformations
for pattern, value in unit_conversion.items():
s = re.sub(pattern, value, s)
s_unit_free = re.sub(pattern, '', s_unit_free)
# Attempt to convert to float
try:
return float(s), [float(s), float(s_unit_free)]
except ValueError:
# Return the original string and the error for debugging purposes
return s, [s, s_unit_free]
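# Illustrative examples for fintabnet_normalize (hypothetical cell values):
#   fintabnet_normalize('$3,852 million')  ->  (3852000000.0, [3852000000.0, 3852.0])
#   fintabnet_normalize('50%')             ->  (0.5, [0.5, 50.0])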
def normalize(x):
if not isinstance(x, str):
x = x.decode('utf8', errors='ignore')
# Remove diacritics
x = ''.join(
c for c in unicodedata.normalize('NFKD', x) if unicodedata.category(c) != 'Mn'
)
# Normalize quotes and dashes
x = re.sub(r'[‘’´`]', "'", x)
x = re.sub(r'[“”]', '"', x)
x = re.sub(r'[‐‑‒–—−]', '-', x)
while True:
old_x = x
# Remove citations
x = re.sub(r'((?<!^)\[[^\]]*\]|\[\d+\]|[•♦†‡*#+])*$', '', x.strip())
# Remove details in parenthesis
x = re.sub(r'(?<!^)( \([^)]*\))*$', '', x.strip())
# Remove outermost quotation mark
x = re.sub(r'^"([^"]*)"$', r'\1', x.strip())
if x == old_x:
break
# Remove final '.'
if x and x[-1] == '.':
x = x[:-1]
# Collapse whitespaces and convert to lower case
x = re.sub(r'\s+', ' ', x, flags=re.U).lower().strip()
return x
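# Illustrative examples for normalize:
#   normalize('  "Allen  Iverson" ')         -> 'allen iverson'
#   normalize('San Francisco (California)')  -> 'san francisco'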
# Value Types
class Value(object):
__metaclass__ = ABCMeta
# Should be populated with the normalized string
_normalized = None
@abstractmethod
def match(self, other):
"""Return True if the value matches the other value.
Args:
other (Value)
Returns:
a boolean
"""
pass
@property
def normalized(self):
return self._normalized
class StringValue(Value):
def __init__(self, content):
assert isinstance(content, str)
self._normalized = normalize(content)
self._hash = hash(self._normalized)
def __eq__(self, other):
return isinstance(other, StringValue) and self.normalized == other.normalized
def __hash__(self):
return self._hash
def __str__(self):
return 'S' + str([self.normalized])
def __repr__(self):
return self.__str__()
def match(self, other):
assert isinstance(other, Value)
return self.normalized == other.normalized
class NumberValue(Value):
def __init__(self, amount, original_string=None):
assert isinstance(amount, (int, float))
if abs(amount - round(amount)) < 1e-6:
self._amount = int(amount)
else:
self._amount = float(amount)
if not original_string:
self._normalized = str(self._amount)
else:
self._normalized = normalize(original_string)
self._hash = hash(self._amount)
@property
def amount(self):
return self._amount
def __eq__(self, other):
return isinstance(other, NumberValue) and self.amount == other.amount
def __hash__(self):
return self._hash
def __str__(self):
return 'N({})'.format(self.amount) + str([self.normalized])
def __repr__(self):
return self.__str__()
def match(self, other):
assert isinstance(other, Value)
if self.normalized == other.normalized:
return True
if isinstance(other, NumberValue):
return abs(self.amount - other.amount) < 1e-6
return False
@staticmethod
def parse(text):
"""Try to parse into a number.
Return:
the number (int or float) if successful; otherwise None.
"""
try:
return int(text)
except ValueError:
try:
amount = float(text)
assert not isnan(amount) and not isinf(amount)
return amount
except ValueError:
return None
class DateValue(Value):
def __init__(self, year, month, day, original_string=None):
"""Create a new DateValue. Placeholders are marked as -1."""
assert isinstance(year, int)
assert isinstance(month, int) and (month == -1 or 1 <= month <= 12)
assert isinstance(day, int) and (day == -1 or 1 <= day <= 31)
assert not (year == month == day == -1)
self._year = year
self._month = month
self._day = day
if not original_string:
self._normalized = '{}-{}-{}'.format(
year if year != -1 else 'xx',
month if month != -1 else 'xx',
                day if day != -1 else 'xx',
)
else:
self._normalized = normalize(original_string)
self._hash = hash((self._year, self._month, self._day))
@property
def ymd(self):
return (self._year, self._month, self._day)
def __eq__(self, other):
return isinstance(other, DateValue) and self.ymd == other.ymd
def __hash__(self):
return self._hash
def __str__(self):
return ('D(%d,%d,%d)' % (self._year, self._month, self._day)) + str(
[self._normalized]
)
__repr__ = __str__
def match(self, other):
assert isinstance(other, Value)
if self.normalized == other.normalized:
return True
if isinstance(other, DateValue):
return self.ymd == other.ymd
return False
@staticmethod
def parse(text):
"""Try to parse into a date.
Return:
tuple (year, month, date) if successful; otherwise None.
"""
try:
ymd = text.lower().split('-')
assert len(ymd) == 3
year = -1 if ymd[0] in ('xx', 'xxxx') else int(ymd[0])
month = -1 if ymd[1] == 'xx' else int(ymd[1])
day = -1 if ymd[2] == 'xx' else int(ymd[2])
assert not (year == month == day == -1)
assert month == -1 or 1 <= month <= 12
assert day == -1 or 1 <= day <= 31
return (year, month, day)
except:
return None
# Value Instantiation
def to_value(original_string, corenlp_value=None):
"""Convert the string to Value object.
Args:
original_string (basestring): Original string
corenlp_value (basestring): Optional value returned from CoreNLP
Returns:
Value
"""
if isinstance(original_string, Value):
# Already a Value
return original_string
if not corenlp_value:
corenlp_value = original_string
# Number?
amount = NumberValue.parse(corenlp_value)
if amount is not None:
return NumberValue(amount, original_string)
# Date?
ymd = DateValue.parse(corenlp_value)
if ymd is not None:
if ymd[1] == ymd[2] == -1:
return NumberValue(ymd[0], original_string)
else:
return DateValue(ymd[0], ymd[1], ymd[2], original_string)
# String.
return StringValue(original_string)
def to_value_list(original_strings, corenlp_values=None):
"""Convert a list of strings to a list of Values
Args:
original_strings (list[basestring])
corenlp_values (list[basestring or None])
Returns:
list[Value]
"""
assert isinstance(original_strings, (list, tuple, set))
if corenlp_values is not None:
assert isinstance(corenlp_values, (list, tuple, set))
assert len(original_strings) == len(corenlp_values)
return list(
set(to_value(x, y) for (x, y) in zip(original_strings, corenlp_values))
)
else:
return list(set(to_value(x) for x in original_strings))
# Check the Predicted Denotations
def check_denotation(target_values, predicted_values):
"""Return True if the predicted denotation is correct.
Args:
target_values (list[Value])
predicted_values (list[Value])
Returns:
bool
"""
# Check size
if len(target_values) != len(predicted_values):
return False
# Check items
for target in target_values:
if not any(target.match(pred) for pred in predicted_values):
return False
return True
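# Illustrative example for check_denotation (hypothetical denotations): values are compared as typed sets, so
#   check_denotation(to_value_list(['2004']), to_value_list(['2004.0']))            # True (numeric match)
#   check_denotation(to_value_list(['2004', 'Freeman']), to_value_list(['2004']))   # False (size mismatch)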
# Batch Mode
def tsv_unescape(x):
"""Unescape strings in the TSV file.
Escaped characters include:
newline (0x10) -> backslash + n
vertical bar (0x7C) -> backslash + p
backslash (0x5C) -> backslash + backslash
Args:
x (str or unicode)
Returns:
a unicode
"""
return x.replace(r'\n', '\n').replace(r'\p', '|').replace('\\\\', '\\')
def tsv_unescape_list(x):
"""Unescape a list in the TSV file.
    List items are joined with vertical bars (0x7C)
Args:
x (str or unicode)
Returns:
a list of unicodes
"""
return [tsv_unescape(y) for y in x.split('|')]
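# Illustrative example: items are split on '|' and escapes are expanded, so
#   tsv_unescape_list(r'New\nYork|Freeman')  ->  ['New\nYork', 'Freeman']
# where the first item contains a real newline after unescaping.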
from ...smp import *
from .multiple_choice import extract_answer_from_item
from PIL import Image, ImageOps
import numpy as np
sys_prompt = "You are an AI assistant for question answering."
system_prompt_multi_choice = (
"You will receive a multi-choice question, the ground-truth answer and the prediction from a question answering (QA) model. " # noqa
"Your task is to determine whether QA model prediction is correct, based on the question and ground-truth answer. "
"If the prediction is correct, respond \"Correct\". If the prediction is incorrect, respond \"Incorrect\"."
)
system_prompt_caption_matching = (
"You will receive a caption matching question, the ground-truth answer and the prediction from a question answering (QA) model. " # noqa
"Your task is to determine whether QA model prediction is correct, based on the question and ground-truth answer. "
"If the prediction is correct, respond \"Correct\". If the prediction is incorrect, respond \"Incorrect\"."
)
system_prompt_captioning = """
You will receive a video description and a multi-choice question. Your task is to choose the correct answer and briefly explain the reason why you choose the answer. \
If none of the choice candidates are correct or the video description lacks enough information to answer the question, just answer "None of the choices are correct". \
Please organize your response in this format:
```
Reasoning: [Your reason to obtain the answer]
Answer: [Your answer]
```
Here are some examples of video description, multi-choice question and the expected answer:
```
Video Description: A person is playing football.
Multi-Choice Question:
What is the person doing in the video?
A. cooking
B. playing football
C. playing basketball
D. reading book
Reasoning: The video description mentions that the person is playing football.
Answer: B. playing football
Video Description: A bird is flying clockwise.
Multi-Choice Question:
In which direction is the bird flying?
A. backward
B. counter-clockwise
C. clockwise
D. downward
Reasoning: The video description mentions that the bird is flying clockwise
Answer: C. clockwise
Video Description: An air balloon is inflating.
Multi-Choice Question:
What is happening to the air balloon?
A. exploding
B. getting smaller
C. flying
Reasoning: The video description mentions that the air balloon is inflating, while none of the choices can be explained as inflating.
Answer: None of the choices are correct
```
""" # noqa
system_prompt_YorN = """
You will receive a Yes/No question, the ground-truth answer and the prediction from a question answering (QA) model. \
Your task is to determine whether QA model prediction is correct, based on the question and ground-truth answer. \
If the prediction is correct, respond "Correct". If the prediction is incorrect, respond "Incorrect".
""" # noqa
def eval_rule_caption_matching(line):
# Determine whether the video llm output is correct, based on word matching rules
video_llm_output = line['prediction']
answer = line['answer']
option_strs = eval(line['candidates']) # complete option strings
option_sents = [opt.split(': ')[1] for opt in option_strs] # option sentence
# option index, e.g., Sentence A, Caption A, Option 1
option_inds = [opt.split(': ')[0] for opt in option_strs] + [opt.split(': ')[0].replace('Sentence ', '').replace('Option ', '').replace('Caption ', '') for opt in option_strs] # noqa
video_llm_pred = None
for option_str in option_strs:
if option_str == video_llm_output:
video_llm_pred = option_str
for option_sent in option_sents:
if option_sent == video_llm_output or (') ' in video_llm_output and option_sent == video_llm_output.split(') ')[1]): # noqa
video_llm_pred = option_sent
for option_ind in option_inds:
if option_ind == video_llm_output or option_ind == video_llm_output.replace('.', ''): # noqa
video_llm_pred = option_ind
if video_llm_pred is None:
return "fail"
else:
return 1 if video_llm_pred == answer or video_llm_pred == answer.split(":")[0] or video_llm_pred == answer.split(": ")[1] or video_llm_pred == answer.split(": ")[0].split()[1] else 0 # noqa
def eval_rule_multi_choice(line):
if line['prediction'] == line['answer']:
return 1
elif line['prediction'] in ['A', 'B', 'C', 'D']:
return 1 if line['prediction'] == line['answer'][0] else 0
elif any(line['prediction'].startswith(prefix) for prefix in ['A.', 'B.', 'C.', 'D.']):
return 1 if line['prediction'].split('.')[0] == line['answer'][0] else 0
elif any(line['prediction'].startswith(prefix) for prefix in ['A)', 'B)', 'C)', 'D)']):
return 1 if line['prediction'].split(')')[0] == line['answer'][0] else 0
else:
return "fail"
def eval_rule_YorN(video_llm_output):
    # Extract the yes/no prediction from the original video llm output
video_llm_output = video_llm_output.lower()
if video_llm_output.startswith("yes"):
return "yes"
elif video_llm_output.startswith("no"):
return "no"
else:
return False
def llm_output_to_rating(llm_output):
if not ('Correct' in llm_output or 'Incorrect' in llm_output):
print(f"Warning: LLM output is not in the correct format: {llm_output}")
rating = 0
return rating
if llm_output.startswith('Correct'):
rating = 1
elif llm_output.startswith('Incorrect'):
rating = 0
elif ('Correct' in llm_output) and ('Incorrect' not in llm_output):
rating = 1
elif 'Incorrect' in llm_output:
rating = 0
return rating
def parse_llm_output(llm_output, gt_answer):
if llm_output == "invalid_request_error" or not llm_output:
eval_result = {"rating": -1, "chatgpt-answer": None, "chatgpt-reasoning": None}
return eval_result
eval_result = {}
lines = llm_output.split("\n")
for line in lines:
line = line.strip()
if "Reasoning" in line:
eval_result['chatgpt-reasoning'] = line.replace("Reasoning:", "").strip()
if "Answer" in line:
eval_result['chatgpt-answer'] = line.replace("Answer:", "").strip()
if "chatgpt-answer" not in eval_result:
eval_result['chatgpt-answer'] = llm_output
if "chatgpt-reasoning" not in eval_result:
eval_result['chatgpt-reasoning'] = None
# Check if the chatgpt answer is the ground-truth answer
# calculate the number of 'A.', 'B.', 'C.', 'D.' in chatgpt-answer
answer_counts = sum(eval_result['chatgpt-answer'].count(prefix) for prefix in ['A.', 'B.', 'C.', 'D.']) # noqa
if eval_result['chatgpt-answer'].split(". ")[0] == gt_answer.split(". ")[0] and answer_counts == 1:
eval_result['rating'] = 1
else:
eval_result['rating'] = 0
return eval_result
def evaluate_tempcompass_mcq(model, line):
eval_rules_dict = {
'caption_matching': eval_rule_caption_matching,
'multi-choice': eval_rule_multi_choice
}
gpt_eval_prompt = {
'multi-choice': '{}\nMulti-Choice Question:\n{}\nGround-Truth Answer: {}\nModel Prediction: {}',
'caption_matching': '{}\nCaption Matching Question:\n{}\nGround-Truth Answer: {}\nModel Prediction: {}'
}
base_prompt = {
'multi-choice': system_prompt_multi_choice,
'caption_matching': system_prompt_caption_matching
}
eval_result = {
"question": line['question'],
"answer": line['answer'],
"prediction": line['prediction'],
"task_type": line['task_type'],
"candidates": line['candidates'],
"match_success": True
}
result = eval_rules_dict[line['task_type']](line)
if result == "fail":
eval_result['match_success'] = False
if model is None:
eval_result['rating'] = 0
else:
prompt_template = gpt_eval_prompt[line['task_type']]
prompt = prompt_template.format(base_prompt[line['task_type']], line['question'], line['answer'], line['prediction']) # noqa
llm_output = model.generate(prompt)
result = llm_output_to_rating(llm_output)
eval_result['chatgpt-response'] = llm_output
eval_result['rating'] = result
else:
eval_result['rating'] = result
return eval_result
def evaluate_tempcompass_captioning(model, line):
prompt = (
f"{system_prompt_captioning}\n"
f"Video Description:{line['prediction']}\n"
f"Multi-Choice Question:\n{line['mc_question']}\n"
)
if model is not None:
llm_output = model.generate(prompt)
eval_result = parse_llm_output(llm_output, gt_answer=line['mc_answer'])
return eval_result
else:
raise ValueError("Model is None, TempCompass Captioning task not supported exact matching") # noqa
def evaluate_tempcompass_YorN(model, line):
prompt = (
f"{system_prompt_YorN}\n"
f"Yes/No Question:\n{line['question']}\n"
f"Ground-Truth Answer: {line['answer']}\n"
f"Model Prediction: {line['prediction']}"
)
result = eval_rule_YorN(line['prediction'])
eval_result = {
"question": line['question'],
"answer": line['answer'],
"prediction": line['prediction'],
"match_success": True
}
if result:
eval_result['rating'] = 1 if result == line['answer'] else 0
elif model is None:
eval_result['match_success'] = False
eval_result['rating'] = 0
else:
eval_result['match_success'] = False
llm_output = model.generate(prompt)
result = llm_output_to_rating(llm_output)
eval_result['chatgpt-response'] = llm_output
eval_result['rating'] = result
return eval_result
def get_dimension_rating(score_file):
data = load(score_file)
result_dict = {}
for idx, item in data.iterrows():
dict_key = item['dim'] + '. ' + item['task_type']
if dict_key not in result_dict:
result_dict[dict_key] = [0,0]
result_dict[dict_key][0] += int(item['score'])
result_dict[dict_key][1] += 1
return result_dict
from ...smp import *
from .multiple_choice import extract_answer_from_item
import numpy as np
import re
FAIL_MSG = 'Failed to obtain answer via API.'
DURATIONS = [
'short',
'medium',
'long',
]
DOMAINS = [
'Knowledge',
'Film & Television',
'Sports Competition',
'Artistic Performance',
'Life Record',
'Multilingual'
]
SUB_CATEGORIES = [
'Humanity & History',
'Literature & Art',
'Biology & Medicine',
'Finance & Commerce',
'Astronomy',
'Geography',
'Law',
'Life Tip',
'Technology',
'Animation',
'Movie & TV Show',
'Documentary',
'News Report',
'Esports',
'Basketball',
'Football',
'Athletics',
'Other Sports',
'Stage Play',
'Magic Show',
'Variety Show',
'Acrobatics',
'Handicraft',
'Food',
'Fashion',
'Daily Life',
'Travel',
'Pet & Animal',
'Exercise',
'Multilingual'
]
TASK_CATEGORIES = [
'Temporal Perception',
'Spatial Perception',
'Attribute Perception',
'Action Recognition',
'Object Recognition',
'OCR Problems',
'Counting Problem',
'Temporal Reasoning',
'Spatial Reasoning',
'Action Reasoning',
'Object Reasoning',
'Information Synopsis',
]
def get_dimension_rating(data_path):
data = load(data_path)
duration_rating = {k: {} for k in DURATIONS}
for duration in DURATIONS + ['overall']:
duration_rating[duration] = {
'overall': '',
'domain': {k: [] for k in DOMAINS},
'sub_category': {k: [] for k in SUB_CATEGORIES},
'task_type': {k: [] for k in TASK_CATEGORIES}
}
for i in range(len(data)):
domain = data.iloc[i]['domain']
sub_ctg = data.iloc[i]['sub_category']
task_ctg = data.iloc[i]['task_type']
duration = data.iloc[i]['duration']
duration_rating[duration]['domain'][domain].append(data.iloc[i]['score'])
duration_rating[duration]['sub_category'][sub_ctg].append(data.iloc[i]['score'])
duration_rating[duration]['task_type'][task_ctg].append(data.iloc[i]['score'])
duration_rating['overall']['domain'][domain].append(data.iloc[i]['score'])
duration_rating['overall']['sub_category'][sub_ctg].append(data.iloc[i]['score'])
duration_rating['overall']['task_type'][task_ctg].append(data.iloc[i]['score'])
for duration in DURATIONS + ['overall']:
overall_res_dur = f'{np.mean([x for x in sum(duration_rating[duration]["domain"].values(), []) if x >= 0]):.3f}'
duration_rating[duration]['overall'] = overall_res_dur
for domain in DOMAINS:
domain_res_dur = f'{np.mean([x for x in duration_rating[duration]["domain"][domain] if x >= 0]):.3f}'
duration_rating[duration]['domain'][domain] = domain_res_dur
for sub_ctg in SUB_CATEGORIES:
sub_res_dur = f'{np.mean([x for x in duration_rating[duration]["sub_category"][sub_ctg] if x >= 0]):.3f}'
duration_rating[duration]['sub_category'][sub_ctg] = sub_res_dur
for task_ctg in TASK_CATEGORIES:
task_res_dur = f'{np.mean([x for x in duration_rating[duration]["task_type"][task_ctg] if x >= 0]):.3f}'
duration_rating[duration]['task_type'][task_ctg] = task_res_dur
return duration_rating
def extract_option(model, input_item, dataset_name):
options = input_item['question'].split('\n')[1:]
for id, option in enumerate(options):
option_id = chr(ord('A') + id) + '.'
if option.find(option_id) >= 0:
input_item[chr(ord('A') + id)] = option[option.find(option_id) + len(option_id):].strip('. \n')
return extract_answer_from_item(model, input_item, dataset_name)['opt']
def extract_characters_regex(s):
s = s.strip()
answer_prefixes = [
'The best answer is',
'The correct answer is',
'The answer is',
'The answer',
'The best option is',
'The correct option is',
'Best answer:',
'Best option:',
'Answer:',
'Option:',
]
for answer_prefix in answer_prefixes:
s = s.replace(answer_prefix, '')
if len(s.split()) > 10 and not re.search('[ABCD]', s):
return ''
matches = re.search(r'[ABCD]', s)
if matches is None:
return ''
return matches[0]
# Copyright (c) OpenMMLab. All rights reserved.
# Partly adopted from https://github.com/GT-Vision-Lab/VQA
# Copyright (c) 2014, Aishwarya Agrawal
from ...smp import *
from typing import Optional
def _process_digit_article(inText):
outText = []
tempText = inText.lower().split()
articles = ['a', 'an', 'the']
manualMap = {
'none': '0',
'zero': '0',
'one': '1',
'two': '2',
'three': '3',
'four': '4',
'five': '5',
'six': '6',
'seven': '7',
'eight': '8',
'nine': '9',
'ten': '10',
}
contractions = {
'aint': "ain't",
'arent': "aren't",
'cant': "can't",
'couldve': "could've",
'couldnt': "couldn't",
"couldn'tve": "couldn't've",
"couldnt've": "couldn't've",
'didnt': "didn't",
'doesnt': "doesn't",
'dont': "don't",
'hadnt': "hadn't",
"hadnt've": "hadn't've",
"hadn'tve": "hadn't've",
'hasnt': "hasn't",
'havent': "haven't",
'hed': "he'd",
"hed've": "he'd've",
"he'dve": "he'd've",
'hes': "he's",
'howd': "how'd",
'howll': "how'll",
'hows': "how's",
"Id've": "I'd've",
"I'dve": "I'd've",
'Im': "I'm",
'Ive': "I've",
'isnt': "isn't",
'itd': "it'd",
"itd've": "it'd've",
"it'dve": "it'd've",
'itll': "it'll",
"let's": "let's",
'maam': "ma'am",
'mightnt': "mightn't",
"mightnt've": "mightn't've",
"mightn'tve": "mightn't've",
'mightve': "might've",
'mustnt': "mustn't",
'mustve': "must've",
'neednt': "needn't",
'notve': "not've",
'oclock': "o'clock",
'oughtnt': "oughtn't",
"ow's'at": "'ow's'at",
"'ows'at": "'ow's'at",
"'ow'sat": "'ow's'at",
'shant': "shan't",
"shed've": "she'd've",
"she'dve": "she'd've",
"she's": "she's",
'shouldve': "should've",
'shouldnt': "shouldn't",
"shouldnt've": "shouldn't've",
"shouldn'tve": "shouldn't've",
"somebody'd": 'somebodyd',
"somebodyd've": "somebody'd've",
"somebody'dve": "somebody'd've",
'somebodyll': "somebody'll",
'somebodys': "somebody's",
'someoned': "someone'd",
"someoned've": "someone'd've",
"someone'dve": "someone'd've",
'someonell': "someone'll",
'someones': "someone's",
'somethingd': "something'd",
"somethingd've": "something'd've",
"something'dve": "something'd've",
'somethingll': "something'll",
'thats': "that's",
'thered': "there'd",
"thered've": "there'd've",
"there'dve": "there'd've",
'therere': "there're",
'theres': "there's",
'theyd': "they'd",
"theyd've": "they'd've",
"they'dve": "they'd've",
'theyll': "they'll",
'theyre': "they're",
'theyve': "they've",
'twas': "'twas",
'wasnt': "wasn't",
"wed've": "we'd've",
"we'dve": "we'd've",
'weve': "we've",
'werent': "weren't",
'whatll': "what'll",
'whatre': "what're",
'whats': "what's",
'whatve': "what've",
'whens': "when's",
'whered': "where'd",
'wheres': "where's",
'whereve': "where've",
'whod': "who'd",
"whod've": "who'd've",
"who'dve": "who'd've",
'wholl': "who'll",
'whos': "who's",
'whove': "who've",
'whyll': "why'll",
'whyre': "why're",
'whys': "why's",
'wont': "won't",
'wouldve': "would've",
'wouldnt': "wouldn't",
"wouldnt've": "wouldn't've",
"wouldn'tve": "wouldn't've",
'yall': "y'all",
"yall'll": "y'all'll",
"y'allll": "y'all'll",
"yall'd've": "y'all'd've",
"y'alld've": "y'all'd've",
"y'all'dve": "y'all'd've",
'youd': "you'd",
"youd've": "you'd've",
"you'dve": "you'd've",
'youll': "you'll",
'youre': "you're",
'youve': "you've",
}
for word in tempText:
word = manualMap.setdefault(word, word)
if word not in articles:
outText.append(word)
for wordId, word in enumerate(outText):
if word in contractions:
outText[wordId] = contractions[word]
outText = ' '.join(outText)
return outText
def hit_calculate(result, dataset_name, anls_threshold=0.5):
if listinstr(['TextVQA'], dataset_name):
return [np.mean(x['match']) for x in result]
elif listinstr(['DocVQA', 'InfoVQA'], dataset_name):
return [0.0 if 1 - np.min(x['match']) < anls_threshold else 1 - np.min(x['match']) for x in result]
elif listinstr(['ChartQA', 'OCRVQA'], dataset_name):
return [np.max(x['match']) for x in result]
else: # default using vqa_score to calculate score
return [np.mean(x['match']) for x in result]
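# Note on the branches above: for DocVQA / InfoVQA the stored 'match' values are ANLS edit
# distances, so the kept score is 1 - min(distance), zeroed out when that similarity drops below
# anls_threshold (0.5 by default); TextVQA and the default branch average per-answer VQA scores,
# while ChartQA / OCRVQA keep the best match over all reference answers.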
# https://github.com/google-research/pix2struct/blob/main/pix2struct/metrics.py#L81
def relaxed_correctness(target: str,
prediction: str,
max_relative_change: float = 0.05) -> bool:
"""Calculates relaxed correctness.
The correctness tolerates certain error ratio defined by max_relative_change.
See https://arxiv.org/pdf/2203.10244.pdf, end of section 5.1:
“Following Methani et al. (2020), we use a relaxed accuracy measure for the
numeric answers to allow a minor inaccuracy that may result from the automatic
data extraction process. We consider an answer to be correct if it is within
5% of the gold answer. For non-numeric answers, we still need an exact match
to consider an answer to be correct.”
Args:
target: Target string.
prediction: Predicted string.
max_relative_change: Maximum relative change.
Returns:
Whether the prediction was correct given the specified tolerance.
"""
def _to_float(text: str) -> Optional[float]:
try:
if text.endswith('%'):
# Convert percentages to floats.
return float(text.rstrip('%')) / 100.0
else:
return float(text)
except ValueError:
return None
prediction = str(prediction)
target = str(target)
prediction_float = _to_float(prediction)
target_float = _to_float(target)
if prediction_float is not None and target_float:
relative_change = abs(prediction_float - target_float) / abs(target_float)
return relative_change <= max_relative_change
else:
return prediction.lower() == target.lower()
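# Worked example (illustrative): relaxed_correctness('10%', '0.096') is True, since _to_float
# maps '10%' -> 0.10 and |0.096 - 0.10| / 0.10 = 0.04 <= 0.05; non-numeric pairs such as
# ('cat', 'Cat') fall back to a case-insensitive exact match.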
def levenshtein_distance(s1, s2):
if len(s1) > len(s2):
s1, s2 = s2, s1
distances = range(len(s1) + 1)
for i2, c2 in enumerate(s2):
distances_ = [i2 + 1]
for i1, c1 in enumerate(s1):
if c1 == c2:
distances_.append(distances[i1])
else:
distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1])))
distances = distances_
return distances[-1]
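# Example: levenshtein_distance('kitten', 'sitting') == 3 (two substitutions plus one insertion).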
def anls_compute(groundtruth, prediction):
gt_answer = ' '.join(groundtruth.strip().lower().split())
det_answer = ' '.join(prediction.strip().lower().split())
dist = levenshtein_distance(gt_answer, det_answer)
length = max(len(groundtruth.upper()), len(prediction.upper()))
values = 0.0 if length == 0 else float(dist) / float(length)
return values
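# anls_compute returns a normalized edit distance, dist / max(len(gt), len(pred)); e.g.
# anls_compute('hello', 'hallo') == 1 / 5 == 0.2. Lower is better; hit_calculate above turns it
# into a similarity via 1 - distance.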
def process_answer(answer):
answer = answer.replace('\n', ' ')
answer = answer.replace('\t', ' ')
answer = answer.strip()
answer = process_punctuation(answer)
answer = _process_digit_article(answer)
return answer
def process_line(line, method='vqa_score'):
ret = {}
if istype(line['answer'], list):
answers = eval(line['answer'])
else:
answers = [line['answer']]
if method == 'vqa_score':
ret['gt'] = [process_answer(x) for x in answers]
ret['pred'] = process_answer(line['prediction'])
ret['match'] = []
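# Standard VQA accuracy: each ground-truth annotation is held out in turn and the prediction is
# scored as min(1, #remaining annotators who gave the predicted answer / 3); the per-annotation
# scores are averaged later in hit_calculate.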
for current_idx, gtAnsDatum in enumerate(ret['gt']):
otherGTAns = [
item for ret_gt_idx, item in enumerate(ret['gt'])
if ret_gt_idx != current_idx
]
matchingAns = [
item for item in otherGTAns if item == ret['pred']
]
acc = min(1, float(len(matchingAns)) / 3)
ret['match'].append(acc)
elif method == 'anls':
ret['gt'] = answers
ret['pred'] = line['prediction']
ret['match'] = [anls_compute(x, ret['pred']) for x in ret['gt']]
elif method == 'relaxed_accuracy':
ret['gt'] = answers
ret['pred'] = line['prediction'].strip()
ret['match'] = [relaxed_correctness(ret['pred'], x) for x in ret['gt']]
elif method == 'accuracy':
ret['gt'] = answers
ret['pred'] = line['prediction'].strip()
ret['match'] = [(1.0 if (x.strip().lower() == ret['pred'].strip().lower()) else 0.0) for x in ret['gt']]
else: # default using vqa_score to calculate score
ret['gt'] = [process_answer(x) for x in answers]
ret['pred'] = process_answer(line['prediction'])
ret['match'] = [x == ret['pred'] for x in ret['gt']]
return ret
from ...smp import *
def AMBER_rating(data_file):
data = load(data_file)
stats = defaultdict(dict)
lt = len(data)
category_mapping = {
'discriminative-attribute-state': 'Attribute',
'discriminative-attribute-number': 'Attribute',
'discriminative-attribute-action': 'Attribute',
'discriminative-hallucination': 'Existence',
'discriminative-relation': 'Relation',
'relation': 'Relation'
}
for i in range(lt):
item = data.iloc[i]
category = item['category']
image_path = item['image_path']
score = item['score']
new_category = category_mapping.get(category, category)
if image_path not in stats[new_category]:
stats[new_category][image_path] = []
stats[new_category][image_path].append(score)
def acc(key):
res = stats[key]
values = []
for val in res.values():
values.extend(val)
return np.mean(values) * 100
scores = {}
for k in stats:
scores[k] = acc(k)
scores['Avg ACC'] = np.mean(list(scores.values()))
ret = d2df(scores)
return ret
def MME_rating(data_file):
data = load(data_file)
stats = defaultdict(dict)
lt = len(data)
for i in range(lt):
item = data.iloc[i]
category = item['category']
image_path = item['image_path']
score = item['score']
if image_path not in stats[category]:
stats[category][image_path] = []
stats[category][image_path].append(score)
def acc(key, mode='normal'):
res = stats[key]
values = []
for val in res.values():
if mode == 'normal':
values.extend(val)
elif mode == 'plus':
values.append(val[0] * val[1])
return np.mean(values) * 100
scores = {}
for k in stats:
scores[k] = acc(k) + acc(k, 'plus')
super_cates = dict(
perception=[
'OCR', 'artwork', 'celebrity', 'color', 'count', 'existence',
'landmark', 'position', 'posters', 'scene'
],
reasoning=['code_reasoning', 'commonsense_reasoning', 'numerical_calculation', 'text_translation']
)
ret = {}
for sc, cate_list in super_cates.items():
base = 0
for c in cate_list:
base += scores[c]
ret[sc] = base
ret.update(scores)
ret = d2df(ret)
return ret
def Hallusion_rating(data_file):
def calc_fAcc(data):
res = defaultdict(list)
lt = len(data)
for i in range(lt):
line = data.iloc[i]
res[f"{line['l2-category']}_{line['set_id']}_{line['figure_id']}"].append(line['score'])
return np.mean([np.all(x) for x in res.values()]) * 100
def calc_qAcc(data):
res = defaultdict(list)
lt = len(data)
for i in range(lt):
line = data.iloc[i]
res[f"{line['l2-category']}_{line['set_id']}_{line['question_id']}"].append(line['score'])
return np.mean([np.all(x) for x in res.values()]) * 100
def calc_aAcc(data):
return np.mean(data['score']) * 100
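# aAcc is per-question accuracy; qAcc counts a question as correct only if all rows sharing its
# (l2-category, set_id, question_id) key are correct; fAcc does the same per figure, grouping by
# (l2-category, set_id, figure_id).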
data = load(data_file)
data['set_id'] = [x.split('_')[3] for x in data['index']]
data['figure_id'] = [x.split('_')[4] for x in data['index']]
data['question_id'] = [x.split('_')[5] for x in data['index']]
res = dict(split=[], aAcc=[], fAcc=[], qAcc=[])
res['split'].append('Overall')
res['aAcc'].append(calc_aAcc(data))
res['fAcc'].append(calc_fAcc(data))
res['qAcc'].append(calc_qAcc(data))
if 'category' in data:
cates = list(set(data['category']))
for c in cates:
sub = data[data['category'] == c]
res['split'].append(c)
res['aAcc'].append(calc_aAcc(sub))
res['fAcc'].append(calc_fAcc(sub))
res['qAcc'].append(calc_qAcc(sub))
if 'l2-category' in data:
cates = list(set(data['l2-category']))
for c in cates:
sub = data[data['l2-category'] == c]
res['split'].append(c)
res['aAcc'].append(calc_aAcc(sub))
res['fAcc'].append(calc_fAcc(sub))
res['qAcc'].append(calc_qAcc(sub))
ret = pd.DataFrame(res)
return ret
def POPE_rating(data_file):
def cal_f1_score(y_true, y_pred):
tp = sum((y_true == 1) & (y_pred == 1))
fp = sum((y_true == 0) & (y_pred == 1))
fn = sum((y_true == 1) & (y_pred == 0))
precision = tp / (tp + fp) if (tp + fp) != 0 else 0
recall = tp / (tp + fn) if (tp + fn) != 0 else 0
f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) != 0 else 0
return f1_score, precision, recall
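# Worked example (illustrative): with tp=8, fp=2, fn=4, precision = 0.8, recall = 8 / 12 (about
# 0.667), and f1 = 2 * 0.8 * 0.667 / (0.8 + 0.667), about 0.727.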
data = load(data_file)
data = data.assign(category=data['category'].str.split(',')).explode('category')
data['index'] = range(len(data))
res = dict(split=[], Overall=[], acc=[], precision=[], recall=[])
y_true = np.array([1 if i == 'Yes' else 0 for i in data['answer']])
y_pred = np.array([1 if i == 'Yes' else 0 for i in data['extracted']])
f1_score, precision, recall = cal_f1_score(y_true, y_pred)
res['split'].append('Overall')
res['Overall'].append(f1_score * 100)
res['acc'].append(np.mean(data['score']) * 100)
res['precision'].append(precision * 100)
res['recall'].append(recall * 100)
if 'category' in data:
cates = list(set(data['category']))
cates = [c for c in cates if not pd.isna(c)]
for c in cates:
sub = data[data['category'] == c]
y_true = np.array([1 if i == 'Yes' else 0 for i in sub['answer']])
y_pred = np.array([1 if i == 'Yes' else 0 for i in sub['extracted']])
f1_score, precision, recall = cal_f1_score(y_true, y_pred)
res['split'].append(c)
res['Overall'].append(f1_score * 100)
res['acc'].append(np.mean(sub['score']) * 100)
res['precision'].append(precision * 100)
res['recall'].append(recall * 100)
ret = pd.DataFrame(res)
return ret
def default_rating(data_file):
data = load(data_file)
res = {}
res['Overall'] = np.mean(data['score']) * 100
if 'category' in data:
cates = list(set(data['category']))
cates = [c for c in cates if not pd.isna(c)]
cates.sort()
for c in cates:
sub = data[data['category'] == c]
res[c] = np.mean(sub['score']) * 100
if 'l2-category' in data:
cates = list(set(data['l2-category']))
cates = [c for c in cates if not pd.isna(c)]
cates.sort()
for c in cates:
sub = data[data['l2-category'] == c]
res[c] = np.mean(sub['score']) * 100
ret = d2df(res)
return ret
def YOrN_match_prompt(line):
tmpl = (
'You are an AI assistant who will help me to match an answer with two options of a question. '
'The options are only Yes / No. '
'You are provided with a question and an answer, '
'and you need to find which option (Yes / No) is most similar to the answer. '
'If the meaning of all options are significantly different from the answer, output Unknown. '
'You should output a single word among the following 3 choices: Yes, No, Unknown.\n'
'Example 1: \n'
"Question: Is the word in this image 'Hello'?\nAnswer: The word in this image is 'Hello'.\nYour output: Yes\n"
'Example 2: \n'
"Question: Is the word in this image 'Hello'?\n"
"Answer: The word in this image is not 'Hello'.\nYour output: No\n"
'Example 3: \n'
'Question: {}?\nAnswer: {}\nYour output: '
)
return tmpl.format(line['question'], line['prediction'])
def YOrN_Extraction(output):
s = output.lower()
words = process_punctuation(s).split()
if 'yes' in words and 'no' not in words:
return 'Yes'
if 'yes' not in words and 'no' in words:
return 'No'
return 'Unknown'
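# Examples (assuming process_punctuation strips the punctuation): YOrN_Extraction('Yes, it is.')
# -> 'Yes'; YOrN_Extraction('No.') -> 'No'; YOrN_Extraction('Yes and no.') -> 'Unknown' because
# both tokens appear.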
def YOrN_auxeval(model, line):
prompt = YOrN_match_prompt(line)
retry = 5
for i in range(retry):
output = model.generate(prompt, temperature=0.5 * i)
ans = YOrN_Extraction(output)
if ans != 'Unknown':
return ans
return 'Unknown'
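# YOrN_auxeval retries the judge up to 5 times at temperatures 0.0, 0.5, 1.0, 1.5 and 2.0,
# returning the first definite Yes/No it can extract, otherwise 'Unknown'.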
import uuid
from functools import partial
from .image_base import ImageBaseDataset
from ..smp import *
rouge = None
nlp_en = None
nlp_zh = None
nlp = None
def initialize():
import evaluate
import spacy
global rouge, nlp_en, nlp_zh, nlp
try:
rouge = evaluate.load('rouge', experiment_id=str(uuid.uuid4()))
except Exception as e:
logging.critical(f'{type(e)}: {e}')
logging.critical('Please first `pip install rouge_score`.')
try:
nlp_en = spacy.load('en_core_web_sm')
except Exception as e:
logging.warning(f'{type(e)}: {e}')
logging.warning('Will automatically download en_core_web_sm via spacy.')
spacy.cli.download('en_core_web_sm')
nlp_en = spacy.load('en_core_web_sm')
try:
nlp_zh = spacy.load('zh_core_web_sm')
except Exception as e:
logging.warning(f'{type(e)}: {e}')
logging.warning('Will automatically download zh_core_web_sm via spacy.')
spacy.cli.download('zh_core_web_sm')
nlp_zh = spacy.load('zh_core_web_sm')
nlp = {'en': nlp_en, 'zh': nlp_zh}
def rough_filter(answer_text):
if "I can't" in answer_text:
return False
elif 'I cannot' in answer_text:
return False
elif 'sorry' in answer_text.lower():
return False
if '无法' in answer_text:
return False
elif '抱歉' in answer_text:
return False
else:
return True
def zero_template(crossed_text):
return {
'crossed_text': crossed_text,
'max_sim_val': 0,
'max_sim_string': '',
'precision': 0,
'recall': 0,
'f1': 0,
'jaccard': 0,
'rouge1': 0,
'exact_match': 0,
}
def tokenize(text, language):
"""
Tokenize the text and return the tokens.
Parameters:
text (str): The text to tokenize.
language (str): The language of the text.
Returns:
list: The list of tokens.
"""
assert language in ['en', 'zh']
nlp_language = nlp[language]
processed_text = nlp_language(text)
return [token.text for token in processed_text]
def find_best_match(needle, hay, language, rouge):
"""
Finds the best matching n-gram in the haystack for the given needle.
Parameters:
needle (str): The string to find.
hay (str): The text to search within.
language (str): The language of the text. Can be "en" or "zh".
rouge: The rouge metric object (loaded via `evaluate`).
Returns:
dict: The best-matching n-gram together with its similarity, precision, recall, f1, jaccard, rouge1 and exact_match scores.
"""
assert language in ['en', 'zh']
from nltk.util import ngrams
from difflib import SequenceMatcher as SM
tokens_hay = tokenize(hay, language)
tokens_needle = tokenize(needle, language)
splitter = '' if language == 'zh' else ' '
ngrams_ = ngrams(tokens_hay, len(tokens_needle))
max_sim_val = 0
max_sim_string = ''
max_sim_ngram = []
tokens_needle_set = set(tokens_needle)
ngrams_hasjoint = [
ngram
for ngram in ngrams_
if not set(ngram).isdisjoint(tokens_needle_set)
]
for ngram in ngrams_hasjoint:
hay_ngram = splitter.join(ngram)
similarity = SM(None, hay_ngram, needle).ratio()
if similarity > max_sim_val:
max_sim_val = similarity
max_sim_string = hay_ngram
max_sim_ngram = ngram
# Evaluate
if len(max_sim_ngram) == 0:
return {
'crossed_text': needle,
'max_sim_val': 0,
'max_sim_string': '',
'precision': 0,
'recall': 0,
'f1': 0,
'jaccard': 0,
'rouge1': 0,
'exact_match': 0,
}
pred_set = set(max_sim_ngram)
ref_set = set(tokens_needle)
correct_tokens = pred_set.intersection(ref_set)
len_correct_tokens = len(correct_tokens)
precision = len_correct_tokens / len(pred_set)
recall = len_correct_tokens / len(ref_set)
if (precision + recall) == 0:
f1 = 0
else:
f1 = 2 * precision * recall / (precision + recall)
union = pred_set.union(ref_set)
jaccard = len_correct_tokens / len(union) if len(union) > 0 else 0
rouge_1 = rouge.compute(
predictions=[max_sim_string],
references=[needle],
tokenizer=partial(tokenize, language=language),
rouge_types=['rouge1'],
)['rouge1']
exact_match = float(list(max_sim_ngram) == list(tokens_needle))
out = {
'crossed_text': needle,
'max_sim_string': max_sim_string,
'max_sim_val': max_sim_val,
'precision': precision,
'recall': recall,
'f1': f1,
'jaccard': jaccard,
'rouge1': rouge_1,
'exact_match': exact_match,
}
return out
def process_match_single_new(
image_id, prediction, answer, language, progress
):
"""
Process the inference results for a single image and calculate the metrics.
Uses the module-level `rouge` metric initialized by `initialize()`.
Parameters:
image_id (int): The image id (question id).
prediction (str): The prediction text.
answer (Union[str, List[str]]): The answer text, or a list of answer texts (the masked n-grams in the image).
language (str): The language of the text. Can be "en" or "zh".
progress (multiprocessing.Queue): The progress queue.
Returns:
tuple: The image id (question_id) and the result per id (dict of dict of dict).
"""
result_per_id = {image_id: {}}
if isinstance(answer, str):
answer = eval(answer)
assert isinstance(answer, list)
result = prediction.split('Assistant: ')[-1]
for i, crossed_text in enumerate(answer):
if rough_filter(result):
find_best_match_result = find_best_match(
crossed_text, result, language, rouge
)
if i == 0:
result_per_id[image_id] = {str(i): find_best_match_result}
else:
result_per_id[image_id][str(i)] = find_best_match_result
else:
if i == 0:
result_per_id[image_id] = {str(i): zero_template(crossed_text)}
else:
result_per_id[image_id][str(i)] = zero_template(crossed_text)
progress.put(1)
return image_id, result_per_id
class VCRDataset(ImageBaseDataset):
TYPE = 'VQA'
URL_PREFIX = 'https://huggingface.co/datasets/vcr-org'
DATASET_URL = {
'VCR_EN_EASY_500': f'{URL_PREFIX}/VCR-wiki-en-easy-test-500/resolve/main/VCR-wiki-en-easy-test-500.tsv',
'VCR_EN_EASY_100': f'{URL_PREFIX}/VCR-wiki-en-easy-test-100/resolve/main/VCR-wiki-en-easy-test-100.tsv',
'VCR_EN_EASY_ALL': f'{URL_PREFIX}/VCR-wiki-en-easy-test/resolve/main/VCR-wiki-en-easy-test.tsv',
'VCR_EN_HARD_500': f'{URL_PREFIX}/VCR-wiki-en-hard-test-500/resolve/main/VCR-wiki-en-hard-test-500.tsv',
'VCR_EN_HARD_100': f'{URL_PREFIX}/VCR-wiki-en-hard-test-100/resolve/main/VCR-wiki-en-hard-test-100.tsv',
'VCR_EN_HARD_ALL': f'{URL_PREFIX}/VCR-wiki-en-hard-test/resolve/main/VCR-wiki-en-hard-test.tsv',
'VCR_ZH_EASY_500': f'{URL_PREFIX}/VCR-wiki-zh-easy-test-500/resolve/main/VCR-wiki-zh-easy-test-500.tsv',
'VCR_ZH_EASY_100': f'{URL_PREFIX}/VCR-wiki-zh-easy-test-100/resolve/main/VCR-wiki-zh-easy-test-100.tsv',
'VCR_ZH_EASY_ALL': f'{URL_PREFIX}/VCR-wiki-zh-easy-test/resolve/main/VCR-wiki-zh-easy-test.tsv',
'VCR_ZH_HARD_500': f'{URL_PREFIX}/VCR-wiki-zh-hard-test-500/resolve/main/VCR-wiki-zh-hard-test-500.tsv',
'VCR_ZH_HARD_100': f'{URL_PREFIX}/VCR-wiki-zh-hard-test-100/resolve/main/VCR-wiki-zh-hard-test-100.tsv',
'VCR_ZH_HARD_ALL': f'{URL_PREFIX}/VCR-wiki-zh-hard-test/resolve/main/VCR-wiki-zh-hard-test.tsv',
}
DATASET_MD5 = {
'VCR_EN_EASY_500': 'fd9258db52f8685dc710619a0ea0a261',
'VCR_EN_EASY_100': '9df5d7266683458621ecbe122beb72f0',
'VCR_EN_EASY_ALL': '8a9b96885f251d1c85f42f84073327f1',
'VCR_EN_HARD_500': '0a22a85080b6a1f52b1f95e302d43df4',
'VCR_EN_HARD_100': '1b20f5cbcbeae0b0bec77f7a36143958',
'VCR_EN_HARD_ALL': '2d8b8b1ee0eba0e0b618fd3aa7d9710e',
'VCR_ZH_EASY_500': 'beca5fd54176adf44cf94bd9b50cf048',
'VCR_ZH_EASY_100': '4a86a5678a79844d6d22ab0629c51cd5',
'VCR_ZH_EASY_ALL': '5050fe7f0027ad2068fd4c7f220edaea',
'VCR_ZH_HARD_500': '617e3360f75c54455625cb0a8da5c1e7',
'VCR_ZH_HARD_100': 'b0e38c85f5d5e63894a3b881c372a62b',
'VCR_ZH_HARD_ALL': '54bbfef448206518b03127ef8b61404c',
}
def __init__(self, dataset='VCR_EN_EASY_500', skip_noimg=True):
super().__init__(dataset, skip_noimg)
initialize()
self.language = 'en' if 'EN' in dataset else 'zh'
self.difficulty = 'easy' if 'EASY' in dataset else 'hard'
# def build_prompt(self, line):
# msgs = super().build_prompt(line)
# assert msgs[-1]['type'] == 'text'
# if self.language == 'zh':
# msgs[-1]['value'] += '图像中被覆盖的文本是什么?请在不输出解释的情况下还原被覆盖的文本。'
# else:
# msgs[-1]['value'] += ('What is the covered texts in the image? '
# 'Please restore the covered texts without outputting the explanations.')
# return msgs
def evaluate(self, eval_file, **judge_kwargs):
import multiprocessing
vcr_score_list = {'Exact_Match': [], 'Jaccard': []}
vcr_score = {'Exact_Match': 0, 'Jaccard': 0}
logger = get_logger('Evaluation')
data = load(eval_file)
lt = len(data)
lines = [data.iloc[i] for i in range(lt)]
pool = multiprocessing.Pool()
manager = multiprocessing.Manager()
progress_queue = manager.Queue()
results = []
overall_results = {str(image_id): {} for image_id in range(len(lines))}
for instance_id, instance in enumerate(lines):
results.append(
pool.apply_async(
process_match_single_new,
args=(
str(instance_id),
instance['prediction'],
instance['answer'],
self.language,
progress_queue,
),
)
)
pool.close()
# Display progress bar
for _ in tqdm(range(len(results))):
progress_queue.get()
pool.join()
# Merging results into overall_result
for result in results:
image_id, result_per_id = result.get()
overall_results[str(image_id)].update(result_per_id[image_id])
for blank_id_str in result_per_id[image_id].keys():
vcr_score_list['Exact_Match'].append(
result_per_id[image_id][blank_id_str]['exact_match']
)
vcr_score_list['Jaccard'].append(
result_per_id[image_id][blank_id_str]['jaccard']
)
vcr_score['Exact_Match'] = np.mean(vcr_score_list['Exact_Match'])
vcr_score['Jaccard'] = np.mean(vcr_score_list['Jaccard'])
results_out = {
k: v for i in range(len(results)) for k, v in results[i].get()[1].items()
}
results_with_metrics = {
'Exact_Match': vcr_score['Exact_Match'],
'Jaccard': vcr_score['Jaccard'],
'Predictions': results_out,
}
score_pth = eval_file.replace(
'.xlsx', f'{self.language}_{self.difficulty}_score.json'
)
dump(results_with_metrics, score_pth)
logger.info(
f'VCR successfully finished evaluating {eval_file}, results saved in {score_pth}'
)
logger.info('Score: ')
for key, value in vcr_score.items():
logger.info('{}:{}'.format(key, value))
from abc import abstractmethod
from ..smp import *
class VideoBaseDataset:
MODALITY = 'VIDEO'
def __init__(self,
dataset='MMBench-Video',
pack=False,
nframe=0,
fps=-1):
try:
import decord
except Exception as e:
logging.critical(f'{type(e)}: {e}')
logging.critical('Please install decord via `pip install decord`.')
self.dataset_name = dataset
ret = self.prepare_dataset(dataset)
assert ret is not None
lmu_root = LMUDataRoot()
self.frame_root = osp.join(lmu_root, 'images', dataset)
os.makedirs(self.frame_root, exist_ok=True)
self.frame_tmpl = 'frame-{}-of-{}.jpg'
self.frame_tmpl_fps = 'frame-{}-of-{}-{}fps.jpg'
self.data_root = ret['root']
self.data_file = ret['data_file']
self.data = load(self.data_file)
assert 'question' in self.data and 'video' in self.data
videos = list(set(self.data['video']))
videos.sort()
self.videos = videos
self.pack = pack
self.nframe = nframe
self.fps = fps
if self.fps > 0 and self.nframe > 0:
raise ValueError('fps and nframe should not be set at the same time')
if self.fps <= 0 and self.nframe <= 0:
raise ValueError('At least one of fps and nframe should be set to a valid value')
def __len__(self):
return len(self.videos) if self.pack else len(self.data)
def __getitem__(self, idx):
if self.pack:
assert idx < len(self.videos)
sub_data = self.data[self.data['video'] == self.videos[idx]]
return sub_data
else:
assert idx < len(self.data)
return dict(self.data.iloc[idx])
def frame_paths(self, video):
frame_root = osp.join(self.frame_root, video)
os.makedirs(frame_root, exist_ok=True)
return [osp.join(frame_root, self.frame_tmpl.format(i, self.nframe)) for i in range(1, self.nframe + 1)]
def frame_paths_fps(self, video, num_frames):
frame_root = osp.join(self.frame_root, video)
os.makedirs(frame_root, exist_ok=True)
return [osp.join(frame_root,
self.frame_tmpl_fps.format(i, num_frames, self.fps)) for i in range(1, num_frames + 1)]
def save_video_frames(self, video):
if self.fps > 0:
vid_path = osp.join(self.data_root, video + '.mp4')
vid = decord.VideoReader(vid_path)
# Compute the total number of frames and the total duration of the video
total_frames = len(vid)
video_fps = vid.get_avg_fps()
total_duration = total_frames / video_fps
# Compute the total number of frames to extract
required_frames = int(total_duration * self.fps)
# Compute the sampling step between extracted frames
step_size = video_fps / self.fps
# Compute the indices of the frames to extract
indices = [int(i * step_size) for i in range(required_frames)]
# Extract the frames and save them
frame_paths = self.frame_paths_fps(video, len(indices))
flag = np.all([osp.exists(p) for p in frame_paths])
if flag:
return frame_paths
images = [vid[i].asnumpy() for i in indices]
images = [Image.fromarray(arr) for arr in images]
for im, pth in zip(images, frame_paths):
if not osp.exists(pth):
im.save(pth)
return frame_paths
else:
frame_paths = self.frame_paths(video)
flag = np.all([osp.exists(p) for p in frame_paths])
if flag:
return frame_paths
vid_path = osp.join(self.data_root, video + '.mp4')
vid = decord.VideoReader(vid_path)
step_size = len(vid) / (self.nframe + 1)
indices = [int(i * step_size) for i in range(1, self.nframe + 1)]
images = [vid[i].asnumpy() for i in indices]
images = [Image.fromarray(arr) for arr in images]
for im, pth in zip(images, frame_paths):
if not osp.exists(pth):
im.save(pth)
return frame_paths
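# Frame-sampling sketch: with fps > 0, one frame is taken every video_fps / fps source frames;
# with nframe > 0, the stride is len(vid) / (nframe + 1), which spreads the samples evenly and
# skips the very first and last frames of the video.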
# Return a list of dataset names that are supported by this class, can override
@classmethod
def supported_datasets(cls):
return ['MMBench-Video', 'Video-MME', 'MVBench', 'MVBench_MP4', 'LongVideoBench']
# Given the prediction file, return the evaluation results in the format of a dictionary or pandas dataframe
@abstractmethod
def evaluate(self, eval_file, **judge_kwargs):
pass
@abstractmethod
def build_prompt(self, idx):
pass
@abstractmethod
def prepare_dataset(self, dataset):
# The prepare_dataset function should return a dictionary containing:
# `root` (the directory containing the video files)
# `data_file` (the TSV dataset file)
pass
from ..smp import *
from .video_base import VideoBaseDataset
class ConcatVideoDataset(VideoBaseDataset):
# This dataset takes multiple dataset names as input and aggregates them into a single dataset.
# Each single dataset should not have a field named `SUB_DATASET`
DATASET_SETS = {}
def __init__(self, dataset, **kwargs):
from . import build_dataset
datasets = self.DATASET_SETS[dataset]
self.dataset_map = {}
# The name of the compilation
self.dataset_name = dataset
self.datasets = datasets
self.nframe = kwargs.get('nframe', 0)
self.fps = kwargs.get('fps', -1)
for dname in datasets:
dataset = build_dataset(dname, **kwargs)
assert dataset is not None, dataset
self.dataset_map[dname] = dataset
TYPES = [x.TYPE for x in self.dataset_map.values()]
MODALITIES = [x.MODALITY for x in self.dataset_map.values()]
# assert np.all([x == TYPES[0] for x in TYPES]), (datasets, TYPES)
assert np.all([x == MODALITIES[0] for x in MODALITIES]), (datasets, MODALITIES)
self.TYPE = TYPES
self.MODALITY = MODALITIES[0]
data_all = []
for dname in datasets:
data = self.dataset_map[dname].data
data['SUB_DATASET'] = [dname] * len(data)
data_all.append(data)
data = pd.concat(data_all)
data['original_index'] = data.pop('index')
data['index'] = np.arange(len(data))
self.data = data
def build_prompt(self, line, video_llm):
if isinstance(line, int):
line = self.data.iloc[line]
idx = line['original_index']
dname = line['SUB_DATASET']
org_data = self.dataset_map[dname].data
org_line = cp.deepcopy(org_data[org_data['index'] == idx]).iloc[0]
return self.dataset_map[dname].build_prompt(org_line, video_llm)
def dump_image(self, line):
# Assert all images are pre-dumped
assert 'image' not in line
assert 'image_path' in line
tgt_path = toliststr(line['image_path'])
return tgt_path
@classmethod
def supported_datasets(cls):
return [] # list(cls.DATASET_SETS)
def evaluate(self, eval_file, **judge_kwargs):
suffix = eval_file.split('.')[-1]
# First, split the eval_file by dataset
data_all = load(eval_file)
for dname in self.datasets:
tgt = eval_file.replace(self.dataset_name, dname)
data_sub = data_all[data_all['SUB_DATASET'] == dname]
data_sub.pop('index')
data_sub['index'] = data_sub.pop('original_index')
data_sub.pop('SUB_DATASET')
dump(data_sub, tgt)
# Then, evaluate each dataset separately
results_all = {}
for dname in self.datasets:
tgt = eval_file.replace(self.dataset_name, dname)
res = self.dataset_map[dname].evaluate(tgt, **judge_kwargs)
results_all.update(res)
result = pd.DataFrame(results_all, index=['success', 'overall'])
result = result.T
for idx, item in result.iterrows():
result.loc[idx, 'acc'] = round(item['success'] / item['overall'] * 100, 1)
score_file = eval_file.replace(f'.{suffix}', '_acc.csv')
dump(result, score_file)
return result
from vlmeval.dataset import *
from functools import partial
mmbench_video_dataset = {
'MMBench_Video_8frame_nopack': partial(MMBenchVideo, dataset='MMBench-Video', nframe=8, pack=False),
'MMBench_Video_8frame_pack': partial(MMBenchVideo, dataset='MMBench-Video', nframe=8, pack=True),
'MMBench_Video_16frame_nopack': partial(MMBenchVideo, dataset='MMBench-Video', nframe=16, pack=False),
'MMBench_Video_1fps_nopack': partial(MMBenchVideo, dataset='MMBench-Video', fps=1.0, pack=False),
'MMBench_Video_1fps_pack': partial(MMBenchVideo, dataset='MMBench-Video', fps=1.0, pack=True)
}
mvbench_dataset = {
'MVBench_8frame': partial(MVBench, dataset='MVBench', nframe=8),
# MVBench does not support fps, but MVBench_MP4 does
'MVBench_MP4_8frame': partial(MVBench_MP4, dataset='MVBench_MP4', nframe=8),
'MVBench_MP4_1fps': partial(MVBench_MP4, dataset='MVBench_MP4', fps=1.0),
}
videomme_dataset = {
'Video-MME_8frame': partial(VideoMME, dataset='Video-MME', nframe=8),
'Video-MME_8frame_subs': partial(VideoMME, dataset='Video-MME', nframe=8, use_subtitle=True),
'Video-MME_1fps': partial(VideoMME, dataset='Video-MME', fps=1.0),
'Video-MME_0.5fps': partial(VideoMME, dataset='Video-MME', fps=0.5),
'Video-MME_0.5fps_subs': partial(VideoMME, dataset='Video-MME', fps=0.5, use_subtitle=True),
}
longvideobench_dataset = {
'LongVideoBench_8frame': partial(LongVideoBench, dataset='LongVideoBench', nframe=8),
'LongVideoBench_8frame_subs': partial(LongVideoBench, dataset='LongVideoBench', nframe=8, use_subtitle=True),
'LongVideoBench_1fps': partial(LongVideoBench, dataset='LongVideoBench', fps=1.0),
'LongVideoBench_0.5fps': partial(LongVideoBench, dataset='LongVideoBench', fps=0.5),
'LongVideoBench_0.5fps_subs': partial(LongVideoBench, dataset='LongVideoBench', fps=0.5, use_subtitle=True)
}
mlvu_dataset = {
'MLVU_8frame': partial(MLVU, dataset='MLVU', nframe=8),
'MLVU_1fps': partial(MLVU, dataset='MLVU', fps=1.0)
}
tempcompass_dataset = {
'TempCompass_8frame': partial(TempCompass, dataset='TempCompass', nframe=8),
'TempCompass_1fps': partial(TempCompass, dataset='TempCompass', fps=1.0),
'TempCompass_0.5fps': partial(TempCompass, dataset='TempCompass', fps=0.5)
}
supported_video_datasets = {}
dataset_groups = [
mmbench_video_dataset, mvbench_dataset, videomme_dataset, longvideobench_dataset,
mlvu_dataset, tempcompass_dataset
]
for grp in dataset_groups:
supported_video_datasets.update(grp)
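# Minimal usage sketch for the registry above (illustrative; downloading and preparing the
# dataset happens on construction):
#   dataset = supported_video_datasets['Video-MME_8frame']()
#   print(len(dataset), dataset.TYPE)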
from huggingface_hub import snapshot_download
from ..smp import *
from .video_base import VideoBaseDataset
from .utils import build_judge, DEBUG_MESSAGE
FAIL_MSG = 'Failed to obtain answer via API.'
def unwrap_hf_pkl(pth, suffix='.mp4'):
base_dir = os.path.join(pth, 'video_pkl/')
target_dir = os.path.join(pth, 'video/')
pickle_files = [os.path.join(base_dir, file) for file in os.listdir(base_dir)]
pickle_files.sort()
if not os.path.exists(target_dir):
os.makedirs(target_dir, exist_ok=True)
for pickle_file in pickle_files:
with open(pickle_file, 'rb') as file:
video_data = pickle.load(file)
# For each video file in the pickle file, write its contents to a new mp4 file
for video_name, video_content in video_data.items():
output_path = os.path.join(target_dir, f'{video_name}{suffix}')
with open(output_path, 'wb') as output_file:
output_file.write(video_content)
print('The video files have been restored from the pickle files.')
else:
print('The video file already exists.')
class VideoMME(VideoBaseDataset):
MD5 = '85bdd91f9b29a99354c23b97ab7c113c'
SYS = ''
FRAMES_TMPL_NOSUB = """
These are the frames of a video. \
Select the best answer to the following multiple-choice question based on the video. \
Respond with only the letter (A, B, C, or D) of the correct option.
"""
FRAMES_TMPL_SUB = """
These are the frames of a video. \
This video's subtitles are listed below:
{}
Select the best answer to the following multiple-choice question based on the video. \
Respond with only the letter (A, B, C, or D) of the correct option.
"""
TYPE = 'Video-MCQ'
def __init__(self, dataset='Video-MME', use_subtitle=False, nframe=0, fps=-1):
super().__init__(dataset=dataset, nframe=nframe, fps=fps)
self.use_subtitle = use_subtitle
self.dataset_name = dataset
@classmethod
def supported_datasets(cls):
return ['Video-MME']
def prepare_dataset(self, dataset_name='Video-MME', repo_id='lmms-lab/Video-MME'):
def check_integrity(pth):
data_file = osp.join(pth, f'{dataset_name}.tsv')
if not os.path.exists(data_file):
return False
if md5(data_file) != self.MD5:
return False
data = load(data_file)
for video_pth in data['video_path']:
if not osp.exists(osp.join(pth, video_pth)):
return False
return True
cache_path = get_cache_path(repo_id)
if cache_path is not None and check_integrity(cache_path):
dataset_path = cache_path
else:
def unzip_hf_zip(pth):
import zipfile
base_dir = pth
target_dir = os.path.join(pth, 'video/')
zip_files = [
os.path.join(base_dir, file) for file in os.listdir(base_dir)
if file.endswith('.zip') and file.startswith('video')
]
zip_files.sort()
if not os.path.exists(target_dir):
os.makedirs(target_dir, exist_ok=True)
for zip_file in zip_files:
with zipfile.ZipFile(zip_file, 'r') as zip_ref:
for member in zip_ref.namelist():
# Check if the member is a file (not a directory)
if not member.endswith('/'):
# Extract the file to the specified directory
source = zip_ref.open(member)
target = open(os.path.join(target_dir, os.path.basename(member)), 'wb')
with source, target:
target.write(source.read())
print('The video files have been restored from the zip files.')
else:
print('The video file already exists.')
subtitle_zip_file = os.path.join(base_dir, 'subtitle.zip')
subtitle_target_dir = os.path.join(base_dir, 'subtitle')
if not os.path.exists(subtitle_target_dir):
os.makedirs(subtitle_target_dir, exist_ok=True)
with zipfile.ZipFile(subtitle_zip_file, 'r') as zip_ref:
for member in zip_ref.namelist():
# Check if the member is a file (not a directory)
if not member.endswith('/'):
# Extract the file to the specified directory
source = zip_ref.open(member)
target = open(os.path.join(subtitle_target_dir, os.path.basename(member)), 'wb')
with source, target:
target.write(source.read())
print('The subtitle files have been restored from the zip file.')
else:
print('The subtitle file already exists.')
def generate_tsv(pth):
data_file = osp.join(pth, f'{dataset_name}.tsv')
if os.path.exists(data_file) and md5(data_file) == self.MD5:
return
data_file = pd.read_parquet(os.path.join(pth, 'videomme/test-00000-of-00001.parquet'))
data_file = data_file.assign(index=range(len(data_file)))
data_file['video'] = data_file['videoID']
data_file['video_path'] = data_file['videoID'].apply(lambda x: f'./video/{x}.mp4')
data_file['subtitle_path'] = data_file['videoID'].apply(lambda x: f'./subtitle/{x}.srt')
data_file['candidates'] = data_file['options'].apply(lambda x: x.tolist())
data_file = data_file[['index', 'video', 'video_path', 'duration', 'domain', 'candidates',
'sub_category', 'task_type', 'subtitle_path', 'question', 'answer']]
data_file.to_csv(osp.join(pth, f'{dataset_name}.tsv'), sep='\t', index=False)
if modelscope_flag_set():
from modelscope import dataset_snapshot_download
dataset_path = dataset_snapshot_download(dataset_id=repo_id)
else:
dataset_path = snapshot_download(repo_id=repo_id, repo_type='dataset')
unzip_hf_zip(dataset_path)
generate_tsv(dataset_path)
data_file = osp.join(dataset_path, f'{dataset_name}.tsv')
return dict(data_file=data_file, root=dataset_path)
def save_video_frames(self, video, video_llm=False):
vid_path = osp.join(self.data_root, 'video', video + '.mp4')
vid = decord.VideoReader(vid_path)
video_info = {
'fps': vid.get_avg_fps(),
'n_frames': len(vid),
}
if self.nframe > 0 and self.fps < 0:
step_size = len(vid) / (self.nframe + 1)
indices = [int(i * step_size) for i in range(1, self.nframe + 1)]
frame_paths = self.frame_paths(video)
elif self.fps > 0:
# not constrained by num_frames, get frames by fps
total_duration = video_info['n_frames'] / video_info['fps']
required_frames = int(total_duration * self.fps)
step_size = video_info['fps'] / self.fps
indices = [int(i * step_size) for i in range(required_frames)]
frame_paths = self.frame_paths_fps(video, len(indices))
flag = np.all([osp.exists(p) for p in frame_paths])
if not flag:
images = [vid[i].asnumpy() for i in indices]
images = [Image.fromarray(arr) for arr in images]
for im, pth in zip(images, frame_paths):
if not osp.exists(pth) and not video_llm:
im.save(pth)
return frame_paths, indices, video_info
def build_prompt(self, line, video_llm):
if isinstance(line, int):
assert line < len(self)
line = self.data.iloc[line]
frames, indices, video_info = self.save_video_frames(line['video'], video_llm)
if self.use_subtitle and os.path.exists(osp.join(self.data_root, line['subtitle_path'])):
import pysubs2
subs = pysubs2.load(osp.join(self.data_root, line['subtitle_path']), encoding='utf-8')
subtitles = []
for selected_frame_id in indices:
sub_text = ''
cur_time = pysubs2.make_time(fps=video_info['fps'], frames=selected_frame_id)
for sub in subs:
if sub.start < cur_time and sub.end > cur_time:
sub_text = sub.text.replace('\\N', ' ')
break
if sub_text.strip():
subtitles.append(sub_text)
subtitles = '\n'.join(subtitles)
else:
subtitles = ''
message = [dict(type='text', value=self.SYS)]
if video_llm:
message.append(dict(type='video', value=osp.join(self.data_root, 'video', line['video'] + '.mp4')))
else:
for im in frames:
message.append(dict(type='image', value=im))
text_prompt = self.FRAMES_TMPL_NOSUB if not self.use_subtitle else self.FRAMES_TMPL_SUB.format(subtitles)
message.append(dict(type='text', value=text_prompt))
line['question'] += '\n' + '\n'.join(eval(line['candidates']))
prompt = 'Question: {}\nAnswer: '.format(line['question'])
message.append(dict(type='text', value=prompt))
return message
# It returns a dictionary
@classmethod
def evaluate(self, eval_file, **judge_kwargs):
from .utils.videomme import get_dimension_rating, extract_characters_regex, extract_option
assert eval_file.endswith('.xlsx'), 'data file should be an xlsx file'
tmp_file = eval_file.replace('.xlsx', '_tmp.pkl')
tgt_file = eval_file.replace('.xlsx', '_rating.json')
score_file = eval_file.replace('.xlsx', '_score.xlsx')
if not osp.exists(score_file):
model = judge_kwargs.get('model', 'exact_matching')
assert model in ['chatgpt-0125', 'exact_matching', 'gpt-4-0125']
if model == 'exact_matching':
model = None
elif gpt_key_set():
model = build_judge(**judge_kwargs)
if not model.working():
warnings.warn('OPENAI API is not working properly, will use exact matching for evaluation')
warnings.warn(DEBUG_MESSAGE)
model = None
else:
warnings.warn('OPENAI_API_KEY is not set properly, will use exact matching for evaluation')
model = None
res = {} if not osp.exists(tmp_file) else load(tmp_file)
res = {k: v for k, v in res.items() if FAIL_MSG not in v}
data = load(eval_file)
data_un = data[~pd.isna(data['prediction'])]
for idx in data['index']:
ans = data.loc[data['index'] == idx, 'answer'].values[0]
pred = str(data.loc[data['index'] == idx, 'prediction'].values[0])
if extract_characters_regex(pred) == '':
extract_pred = extract_option(
model,
data.loc[data['index'] == idx].to_dict(orient='records')[0],
'Video-MME'
)
data.loc[idx, 'score'] = int(extract_pred == ans)
else:
data.loc[idx, 'score'] = int(extract_characters_regex(pred) == ans)
rejected = [x for x in data['score'] if x == -1]
print(
f'Among {len(data)} questions, failed to obtain prediction for {len(data) - len(data_un)} questions, '
f'failed to obtain the score for another {len(rejected)} questions. '
f'Those questions will be counted as -1 score in ALL rating, and will not be counted in VALID rating.'
)
dump(data, score_file)
rating = get_dimension_rating(score_file)
dump(rating, tgt_file)
return rating
import re
from functools import partial
from .image_base import ImageBaseDataset
from .utils import build_judge, DEBUG_MESSAGE
from ..smp import *
from ..utils import track_progress_rich
SYSTEM_PROMPT = """\
Please act as an impartial judge and evaluate the quality of the responses provided by two AI assistants to the user \
prompt displayed below. You will be given assistant A's answer and assistant B's answer. Your job is to evaluate \
which assistant's answer is better.
Begin your evaluation by generating your own answer to the prompt. You must provide your answers before judging any \
answers.
When evaluating the assistants' answers, compare both assistants' answers with your answer. \
You must identify and correct any mistakes or inaccurate information.
Then consider if the assistant's answers are helpful, relevant, and concise. Helpful means the answer correctly \
responds to the prompt or follows the instructions. Note when user prompt has any ambiguity or more than one \
interpretation, it is more helpful and appropriate to ask for clarifications or more information from the user than \
providing an answer based on assumptions. Relevant means all parts of the response closely connect or are appropriate \
to what is being asked. Concise means the response is clear and not verbose or excessive.
Then consider the creativity and novelty of the assistant's answers when needed. Finally, identify any missing \
important information in the assistants' answers that would be beneficial to include when responding to the user \
prompt.
After providing your explanation, you must output only one of the following choices as your final verdict with a label:
1. Assistant A is significantly better: [[A>>B]]
2. Assistant A is slightly better: [[A>B]]
3. Tie, relatively the same: [[A=B]]
4. Assistant B is slightly better: [[B>A]]
5. Assistant B is significantly better: [[B>>A]]
Example output: "My final verdict is tie: [[A=B]]".\
"""
PROMPT_TEMPLATE = """\
"<|User Prompt|>\n{question}
<|The Start of Assistant A's Answer|>\n{answer_1}\n<|The End of Assistant A's Answer|>
<|The Start of Assistant B's Answer|>\n{answer_2}\n<|The End of Assistant B's Answer|>
"""
REGEX_PATTERN = re.compile(r"\[\[([AB<>=]+)\]\]")
def get_score(judgement, pattern=REGEX_PATTERN):
matches = pattern.findall(judgement)
matches = [m for m in matches if m != ""]
if len(set(matches)) == 0:
return None, True
elif len(set(matches)) == 1:
return matches[0].strip("\n"), False
else:
return None, True
def WildVision_auxeval(model, line):
config = dict(question=line['question'], answer_1=line['A'], answer_2=line['B'])
prompt = PROMPT_TEMPLATE.format(**config)
prefix = 'data:image/jpeg;base64,'
img = prefix + line['image']
messages = [
dict(type='text', value=prompt),
dict(type='image', value=img)
]
retry = 2
while retry:
resp = model.generate(messages)
score, try_again = get_score(resp)
if not try_again:
break
retry -= 1
if score is None:
return 'Unknown'
return score
class WildVision(ImageBaseDataset):
TYPE = 'VQA'
DATASET_URL = {
'WildVision': 'https://opencompass.openxlab.space/utils/VLMEval/WildVision.tsv'
}
DATASET_MD5 = {'WildVision': 'b38f80156d49411c594772866b0d0b52'}
score_map = {
'A>>B': -2,
'A>B': -1,
'A=B': 0,
'B>A': 1,
'B>>A': 2
}
# Given one data record, return the built prompt (a multi-modal message), can override
def build_prompt(self, line):
if isinstance(line, int):
line = self.data.iloc[line]
if self.meta_only:
tgt_path = toliststr(line['image_path'])
else:
tgt_path = self.dump_image(line)
question = line['question']
msgs = []
if isinstance(tgt_path, list):
msgs.extend([dict(type='image', value=p) for p in tgt_path])
else:
msgs = [dict(type='image', value=tgt_path)]
# WildVision adopts text first
msgs = [dict(type='text', value=question)] + msgs
return msgs
@classmethod
def gen_eval_base(self, eval_file, b64_map):
data = load(eval_file)
data['B'] = data.pop('prediction')
data['A'] = data.pop('claude3_sonnet')
data['image'] = [b64_map[x] for x in data['index']]
return data
# rev = cp.deepcopy(data)
# rev['A'] = data['B']
# rev['B'] = data['A']
# rev['index'] = [x + '_rev' for x in data['index']]
# return pd.concat([data, rev], ignore_index=True)
# It returns a DataFrame
@classmethod
def evaluate(self, eval_file, **judge_kwargs):
# We adopt pairwise evaluation (twice for a pair) for this dataset
suffix = eval_file.split('.')[-1]
model = judge_kwargs['model']
storage = eval_file.replace(f'.{suffix}', f'_{model}.xlsx')
score_file = eval_file.replace(f'.{suffix}', f'_{model}_score.csv')
tmp_file = eval_file.replace(f'.{suffix}', f'_{model}.pkl')
nproc = judge_kwargs.pop('nproc', 4)
if not osp.exists(storage):
raw_data = WildVision('WildVision').data
b64_map = {x: y for x, y in zip(raw_data['index'], raw_data['image'])}
data = self.gen_eval_base(eval_file, b64_map)
judge_kwargs['system_prompt'] = SYSTEM_PROMPT
judge_kwargs['temperature'] = 0
judge_kwargs['img_detail'] = 'high'
judge_kwargs['timeout'] = 300
model = build_judge(max_tokens=4096, **judge_kwargs)
assert model.working(), ('WildVision evaluation requires a working OPENAI API\n' + DEBUG_MESSAGE)
lt = len(data)
lines = [data.iloc[i] for i in range(lt)]
tups = [(model, line) for line in lines]
indices = [line['index'] for line in lines]
ans = load(tmp_file) if osp.exists(tmp_file) else {}
tups = [x for x, i in zip(tups, indices) if i not in ans]
indices = [i for i in indices if i not in ans]
if len(indices):
new_results = track_progress_rich(
WildVision_auxeval,
tups,
nproc=nproc,
chunksize=nproc,
keys=indices,
save=tmp_file,
)
ans = load(tmp_file)
for k, v in zip(indices, new_results):
ans[k] = v
data['score'] = [ans[idx] for idx in data['index']]
data.pop('image')
dump(data, storage)
data = load(storage)
lt = len(data)
scores = defaultdict(lambda: 0)
for i in range(lt):
item = data.iloc[i]
if item['score'] not in self.score_map:
score = 0
else:
score = self.score_map[item['score']]
if '_rev' in item['index']:
score = -score
scores[score] += 1
name_map = {
2: 'Much Better',
1: 'Better',
0: 'Tie',
-1: 'Worse',
-2: 'Much Worse'
}
scores = {name_map[k]: v for k, v in scores.items()}
much_better = scores.get('Much Better', 0)
better = scores.get('Better', 0)
worse = scores.get('Worse', 0)
much_worse = scores.get('Much Worse', 0)
scores['Reward'] = (
100 * much_better + 50 * better - 50 * worse - 100 * much_worse
) / lt
scores['Win Rate'] = (better + much_better) / lt
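# Reward is a pairwise preference score in [-100, 100]: +100 per 'Much Better', +50 per 'Better',
# -50 per 'Worse', -100 per 'Much Worse', averaged over all comparisons; Win Rate is the fraction
# of comparisons in which the evaluated model is preferred at all.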
scores = {k: [v] for k, v in scores.items()}
scores = pd.DataFrame(scores)
dump(scores, score_file)
return scores
import torch
import torch.distributed as dist
from vlmeval.config import supported_VLM
from vlmeval.utils import track_progress_rich
from vlmeval.smp import *
FAIL_MSG = 'Failed to obtain answer via API.'
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--data', type=str, nargs='+', required=True)
parser.add_argument('--model', type=str, nargs='+', required=True)
parser.add_argument('--nproc', type=int, default=4, required=True)
parser.add_argument('--verbose', action='store_true')
args = parser.parse_args()
return args
# Only API model is accepted
def infer_data_api(model, work_dir, model_name, dataset, index_set=None, api_nproc=4, ignore_failed=False):
rank, world_size = get_rank_and_world_size()
assert rank == 0 and world_size == 1
dataset_name = dataset.dataset_name
data = dataset.data
if index_set is not None:
data = data[data['index'].isin(index_set)]
model = supported_VLM[model_name]() if isinstance(model, str) else model
assert getattr(model, 'is_api', False)
if hasattr(model, 'set_dump_image'):
model.set_dump_image(dataset.dump_image)
lt, indices = len(data), list(data['index'])
structs = []
for i in range(lt):
item = data.iloc[i]
if hasattr(model, 'use_custom_prompt') and model.use_custom_prompt(dataset_name):
assert hasattr(model, 'build_prompt')
struct = model.build_prompt(item, dataset=dataset_name)
else:
struct = dataset.build_prompt(item)
structs.append(struct)
# structs = [dataset.build_prompt(data.iloc[i]) for i in range(lt)]
out_file = f'{work_dir}/{model_name}_{dataset_name}_supp.pkl'
res = {}
if osp.exists(out_file):
res = load(out_file)
if ignore_failed:
res = {k: v for k, v in res.items() if FAIL_MSG not in v}
structs = [s for i, s in zip(indices, structs) if i not in res]
indices = [i for i in indices if i not in res]
gen_func = model.generate
structs = [dict(message=struct, dataset=dataset_name) for struct in structs]
if len(structs):
track_progress_rich(gen_func, structs, nproc=api_nproc, chunksize=api_nproc, save=out_file, keys=indices)
res = load(out_file)
if index_set is not None:
res = {k: v for k, v in res.items() if k in index_set}
os.remove(out_file)
return res
def infer_data(model, model_name, work_dir, dataset, out_file, verbose=False, api_nproc=4):
dataset_name = dataset.dataset_name
prev_file = f'{work_dir}/{model_name}_{dataset_name}_PREV.pkl'
res = load(prev_file) if osp.exists(prev_file) else {}
if osp.exists(out_file):
res.update(load(out_file))
rank, world_size = get_rank_and_world_size()
sheet_indices = list(range(rank, len(dataset), world_size))
lt = len(sheet_indices)
data = dataset.data.iloc[sheet_indices]
data_indices = [i for i in data['index']]
# If finished, will exit without building the model
all_finished = True
for i in range(lt):
idx = data.iloc[i]['index']
if idx not in res:
all_finished = False
if all_finished:
res = {k: res[k] for k in data_indices}
dump(res, out_file)
return
# Data need to be inferred
data = data[~data['index'].isin(res)]
lt = len(data)
model = supported_VLM[model_name]() if isinstance(model, str) else model
is_api = getattr(model, 'is_api', False)
if is_api:
lt, indices = len(data), list(data['index'])
supp = infer_data_api(
model=model,
work_dir=work_dir,
model_name=model_name,
dataset=dataset,
index_set=set(indices),
api_nproc=api_nproc)
for idx in indices:
assert idx in supp
res.update(supp)
res = {k: res[k] for k in data_indices}
dump(res, out_file)
return model
else:
model.set_dump_image(dataset.dump_image)
for i in tqdm(range(lt)):
idx = data.iloc[i]['index']
if idx in res:
continue
if hasattr(model, 'use_custom_prompt') and model.use_custom_prompt(dataset_name):
struct = model.build_prompt(data.iloc[i], dataset=dataset_name)
else:
struct = dataset.build_prompt(data.iloc[i])
response = model.generate(message=struct, dataset=dataset_name)
torch.cuda.empty_cache()
if verbose:
print(response, flush=True)
res[idx] = response
if (i + 1) % 10 == 0:
dump(res, out_file)
res = {k: res[k] for k in data_indices}
dump(res, out_file)
return model
# A wrapper around infer_data that handles the pre- and post-processing.
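# On rank 0 it first salvages predictions from an existing result xlsx into a PREV pkl, then
# every rank runs infer_data into its own pkl; finally rank 0 merges the per-rank files into
# the final xlsx and removes the temporary pkl files.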
def infer_data_job(model, work_dir, model_name, dataset, verbose=False, api_nproc=4, ignore_failed=False):
rank, world_size = get_rank_and_world_size()
dataset_name = dataset.dataset_name
result_file = osp.join(work_dir, f'{model_name}_{dataset_name}.xlsx')
prev_file = f'{work_dir}/{model_name}_{dataset_name}_PREV.pkl'
if osp.exists(result_file):
if rank == 0:
data = load(result_file)
results = {k: v for k, v in zip(data['index'], data['prediction'])}
if not ignore_failed:
results = {k: v for k, v in results.items() if FAIL_MSG not in str(v)}
dump(results, prev_file)
if world_size > 1:
dist.barrier()
tmpl = osp.join(work_dir, '{}' + f'{world_size}_{dataset_name}.pkl')
out_file = tmpl.format(rank)
model = infer_data(
model=model, work_dir=work_dir, model_name=model_name, dataset=dataset,
out_file=out_file, verbose=verbose, api_nproc=api_nproc)
if world_size > 1:
dist.barrier()
if rank == 0:
data_all = {}
for i in range(world_size):
data_all.update(load(tmpl.format(i)))
data = dataset.data
for x in data['index']:
assert x in data_all
data['prediction'] = [str(data_all[x]) for x in data['index']]
if 'image' in data:
data.pop('image')
dump(data, result_file)
for i in range(world_size):
os.remove(tmpl.format(i))
if world_size > 1:
dist.barrier()
return model
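# A minimal usage sketch (hypothetical names; in practice this is driven by vlmeval's run
# script, and `dataset` is assumed to expose .data, .dataset_name and build_prompt):
#
#   model = infer_data_job(
#       model='MyModel',            # hypothetical key assumed to exist in supported_VLM
#       work_dir='outputs/MyModel',
#       model_name='MyModel',
#       dataset=dataset,
#       verbose=True,
#       api_nproc=4)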
import torch
import torch.distributed as dist
from vlmeval.config import supported_VLM
from vlmeval.utils import track_progress_rich
from vlmeval.smp import *
FAIL_MSG = 'Failed to obtain answer via API.'
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--data', type=str, nargs='+', required=True)
parser.add_argument('--model', type=str, nargs='+', required=True)
parser.add_argument('--nproc', type=int, default=4, required=True)
parser.add_argument('--verbose', action='store_true')
args = parser.parse_args()
return args
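# chat_mt: multi-turn helper. messages holds 2 * nturn entries; only the even-indexed (user)
# turns are consumed, and the model's reply (or FAIL_MSG plus the exception text on error) is
# appended as the assistant turn before the next round. Returns one prediction per turn.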
def chat_mt(model, messages, dataset_name):
assert len(messages) % 2 == 0
nturn = len(messages) // 2
utter_stack = []
predictions = []
for i in range(nturn):
utter = messages[2 * i]
utter_stack.append(utter)
try:
resp = model.chat(utter_stack, dataset=dataset_name)
utter_stack.append(dict(role='assistant', content=resp))
except Exception as e:
resp = FAIL_MSG + str(e)
utter_stack.append(dict(role='assistant', content=resp))
predictions.append(resp)
return predictions
# Only API models are accepted
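# infer_data_api (multi-turn variant): same resumable API flow as above, but each item is a
# full conversation answered via chat_mt, and the model must implement chat_inner.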
def infer_data_api(model, work_dir, model_name, dataset, index_set=None, api_nproc=4, ignore_failed=False):
rank, world_size = get_rank_and_world_size()
assert rank == 0 and world_size == 1
dataset_name = dataset.dataset_name
data = dataset.data
if index_set is not None:
data = data[data['index'].isin(index_set)]
model = supported_VLM[model_name]() if isinstance(model, str) else model
assert getattr(model, 'is_api', False)
assert hasattr(model, 'chat_inner')
lt, indices = len(data), list(data['index'])
structs = [dataset.build_prompt(data.iloc[i]) for i in range(lt)]
out_file = f'{work_dir}/{model_name}_{dataset_name}_supp.pkl'
res = {}
if osp.exists(out_file):
res = load(out_file)
if ignore_failed:
res = {k: v for k, v in res.items() if FAIL_MSG not in v}
structs = [s for i, s in zip(indices, structs) if i not in res]
indices = [i for i in indices if i not in res]
structs = [dict(model=model, messages=struct, dataset_name=dataset_name) for struct in structs]
if len(structs):
track_progress_rich(chat_mt, structs, nproc=api_nproc, chunksize=api_nproc, save=out_file, keys=indices)
res = load(out_file)
if index_set is not None:
res = {k: v for k, v in res.items() if k in index_set}
os.remove(out_file)
return res
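# infer_data (multi-turn variant): per-rank loop mirroring the single-turn version, except each
# prompt is a conversation answered turn by turn via chat_mt and partial results are dumped
# every 20 samples.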
def infer_data(model, model_name, work_dir, dataset, out_file, verbose=False, api_nproc=4):
dataset_name = dataset.dataset_name
res = {}
if osp.exists(out_file):
res.update(load(out_file))
rank, world_size = get_rank_and_world_size()
sheet_indices = list(range(rank, len(dataset), world_size))
lt = len(sheet_indices)
data = dataset.data.iloc[sheet_indices]
data_indices = [i for i in data['index']]
    # If all samples are already finished, exit without building the model
all_finished = True
for i in range(lt):
idx = data.iloc[i]['index']
if idx not in res:
all_finished = False
if all_finished:
res = {k: res[k] for k in data_indices}
dump(res, out_file)
return
    # Data that still needs to be inferred
data = data[~data['index'].isin(res)]
lt = len(data)
model = supported_VLM[model_name]() if isinstance(model, str) else model
assert hasattr(model, 'chat_inner')
is_api = getattr(model, 'is_api', False)
if is_api:
lt, indices = len(data), list(data['index'])
supp = infer_data_api(
model=model,
work_dir=work_dir,
model_name=model_name,
dataset=dataset,
index_set=set(indices),
api_nproc=api_nproc)
for idx in indices:
assert idx in supp
res.update(supp)
res = {k: res[k] for k in data_indices}
dump(res, out_file)
return model
else:
model.set_dump_image(dataset.dump_image)
for i in tqdm(range(lt)):
idx = data.iloc[i]['index']
if idx in res:
continue
if hasattr(model, 'use_custom_prompt') and model.use_custom_prompt(dataset_name):
struct = model.build_prompt(data.iloc[i], dataset=dataset_name)
else:
struct = dataset.build_prompt(data.iloc[i])
response = chat_mt(model, struct, dataset_name)
torch.cuda.empty_cache()
if verbose:
print(response, flush=True)
res[idx] = response
if (i + 1) % 20 == 0:
dump(res, out_file)
res = {k: res[k] for k in data_indices}
dump(res, out_file)
return model
# A wrapper around infer_data that handles the pre- and post-processing.
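# infer_data_job_mt skips the PREV-file salvage step and writes the merged multi-turn
# predictions to a tsv instead of an xlsx.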
def infer_data_job_mt(model, work_dir, model_name, dataset, verbose=False, api_nproc=4, ignore_failed=False):
rank, world_size = get_rank_and_world_size()
dataset_name = dataset.dataset_name
result_file = osp.join(work_dir, f'{model_name}_{dataset_name}.tsv')
tmpl = osp.join(work_dir, '{}' + f'{world_size}_{dataset_name}.pkl')
out_file = tmpl.format(rank)
model = infer_data(
        model=model, model_name=model_name, work_dir=work_dir, dataset=dataset,
out_file=out_file, verbose=verbose, api_nproc=api_nproc)
if world_size > 1:
dist.barrier()
if rank == 0:
data_all = {}
for i in range(world_size):
data_all.update(load(tmpl.format(i)))
data = dataset.data
for x in data['index']:
assert x in data_all
data['prediction'] = [data_all[x] for x in data['index']]
if 'image' in data:
data.pop('image')
dump(data, result_file)
for i in range(world_size):
os.remove(tmpl.format(i))
return model
import torch
import torch.distributed as dist
from vlmeval.config import supported_VLM
from vlmeval.utils import track_progress_rich
from vlmeval.smp import *
FAIL_MSG = 'Failed to obtain answer via API.'
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--data', type=str, nargs='+', required=True)
parser.add_argument('--model', type=str, nargs='+', required=True)
parser.add_argument('--nproc', type=int, default=4, required=True)
parser.add_argument('--verbose', action='store_true')
args = parser.parse_args()
return args
# Only API models are accepted
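# infer_data_api (video variant): builds prompts from whole video samples (packed or not),
# names the supplementary cache after the nframe or fps setting, and re-queries items whose
# cached answer equals FAIL_MSG.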
def infer_data_api(model, work_dir, model_name, dataset, samples_dict={}, api_nproc=4):
rank, world_size = get_rank_and_world_size()
assert rank == 0 and world_size == 1
dataset_name = dataset.dataset_name
model = supported_VLM[model_name]() if isinstance(model, str) else model
assert getattr(model, 'is_api', False)
indices = list(samples_dict.keys())
structs = [dataset.build_prompt(samples_dict[idx], video_llm=getattr(model, 'VIDEO_LLM', False)) for idx in indices]
packstr = 'pack' if getattr(dataset, 'pack', False) else 'nopack'
if dataset.nframe > 0:
out_file = f'{work_dir}/{model_name}_{dataset_name}_{dataset.nframe}frame_{packstr}_supp.pkl'
else:
out_file = f'{work_dir}/{model_name}_{dataset_name}_{dataset.fps}fps_{packstr}_supp.pkl'
res = load(out_file) if osp.exists(out_file) else {}
structs = [s for i, s in zip(indices, structs) if i not in res or res[i] == FAIL_MSG]
indices = [i for i in indices if i not in res or res[i] == FAIL_MSG]
gen_func = model.generate
structs = [dict(message=struct, dataset=dataset_name) for struct in structs]
if len(structs):
track_progress_rich(gen_func, structs, nproc=api_nproc, chunksize=api_nproc, save=out_file, keys=indices)
res = load(out_file)
return res
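# infer_data (video variant): per-rank loop over video samples. Before each query it reconciles
# the model's nframe / fps attributes with the dataset settings (raising if the sampling modes
# are incompatible), applies a per-sample SUB_DATASET override when present, then builds either
# the model's custom prompt or the dataset prompt and generates.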
def infer_data(model, model_name, work_dir, dataset, out_file, verbose=False, api_nproc=4):
res = load(out_file) if osp.exists(out_file) else {}
rank, world_size = get_rank_and_world_size()
dataset_name = dataset.dataset_name
sample_indices = list(dataset.videos) if getattr(dataset, 'pack', False) else list(dataset.data['index'])
samples = list(dataset.videos) if getattr(dataset, 'pack', False) else list(range(len(dataset.data)))
sample_map = {i: s for i, s in zip(sample_indices, samples)}
sample_indices_sub = sample_indices[rank::world_size]
if np.all([idx in res for idx in sample_indices_sub]):
return model
sample_indices_subrem = [x for x in sample_indices_sub if x not in res]
model = supported_VLM[model_name]() if isinstance(model, str) else model
is_api = getattr(model, 'is_api', False)
if is_api:
assert world_size == 1
supp = infer_data_api(
model=model,
work_dir=work_dir,
model_name=model_name,
dataset=dataset,
samples_dict={k: sample_map[k] for k in sample_indices_subrem},
api_nproc=api_nproc)
for k in sample_indices_subrem:
assert k in supp
res.update(supp)
dump(res, out_file)
return model
    assert not getattr(dataset, 'pack', False), 'The current model does not support pack mode!'
for i, idx in tqdm(enumerate(sample_indices_subrem)):
if idx in res:
continue
if getattr(model, 'nframe', None) is not None and getattr(model, 'nframe', 0) > 0:
if dataset.nframe > 0:
if getattr(model, 'nframe', 0) != dataset.nframe:
print(f'{model_name} is a video-llm model, nframe is set to {dataset.nframe}, not using default')
setattr(model, 'nframe', dataset.nframe)
elif getattr(model, 'fps', 0) == 0:
raise ValueError(f'fps is not suitable for {model_name}')
else:
setattr(model, 'nframe', None)
if getattr(model, 'fps', None) is not None and getattr(model, 'fps', 0) > 0:
if dataset.fps > 0:
if getattr(model, 'fps', 0) != dataset.fps:
print(f'{model_name} is a video-llm model, fps is set to {dataset.fps}, not using default')
setattr(model, 'fps', dataset.fps)
elif getattr(model, 'nframe', 0) == 0:
raise ValueError(f'nframe is not suitable for {model_name}')
else:
setattr(model, 'fps', None)
if 'SUB_DATASET' in dataset.data.iloc[sample_map[idx]]:
dataset_name = dataset.data.iloc[sample_map[idx]]['SUB_DATASET']
if hasattr(model, 'use_custom_prompt') and model.use_custom_prompt(dataset_name):
if dataset.nframe == 0:
raise ValueError(f'nframe must be set for custom prompt, fps is not suitable for {model_name}')
struct = model.build_prompt(
dataset.data.iloc[sample_map[idx]], dataset=dataset, video_llm=getattr(model, 'VIDEO_LLM', False)
)
else:
struct = dataset.build_prompt(
sample_map[idx], video_llm=getattr(model, 'VIDEO_LLM', False)
)
response = model.generate(message=struct, dataset=dataset_name)
torch.cuda.empty_cache()
if verbose:
print(response, flush=True)
res[idx] = response
if (i + 1) % 20 == 0:
dump(res, out_file)
res = {k: res[k] for k in sample_indices_sub}
dump(res, out_file)
return model
# A wrapper around infer_data that handles the pre- and post-processing.
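# If the result file already exists the wrapper returns immediately; otherwise rank 0 merges
# the per-rank pkl files and, for packed MMBench-Video runs, unpacks the answers via
# dataset.load_pack_answers before dumping the final result file.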
def infer_data_job_video(
model,
work_dir,
model_name,
dataset,
result_file_name,
verbose=False,
api_nproc=4):
dataset_name = dataset.dataset_name
rank, world_size = get_rank_and_world_size()
result_file = osp.join(work_dir, result_file_name)
    # If the result file already exists, skip inference and return the model unchanged
if osp.exists(result_file):
return model
tmpl = osp.join(work_dir, '{}' + f'{world_size}_{osp.splitext(result_file_name)[0]}.pkl')
out_file = tmpl.format(rank)
model = infer_data(
model=model,
model_name=model_name,
work_dir=work_dir,
dataset=dataset,
out_file=out_file,
verbose=verbose,
api_nproc=api_nproc)
if world_size > 1:
dist.barrier()
if rank == 0:
data_all = {}
for i in range(world_size):
data_all.update(load(tmpl.format(i)))
meta = dataset.data
if dataset_name == 'MMBench-Video' and getattr(dataset, 'pack', False):
meta, vstats = dataset.load_pack_answers(data_all)
            print(f'Statistics of Pack Video Inference: {vstats}')
else:
for x in meta['index']:
assert x in data_all
meta['prediction'] = [str(data_all[x]) for x in meta['index']]
if 'image' in meta:
meta.pop('image')
dump(meta, result_file)
for i in range(world_size):
os.remove(tmpl.format(i))
return model
from .file import *
from .vlm import *
from .misc import *
from .log import *