Commit 7a60e044 authored by wanglch
Initial commit
import os
import io
import pandas as pd
import numpy as np
import string
from uuid import uuid4
import os.path as osp
import base64
from PIL import Image
from .file import load, dump
Image.MAX_IMAGE_PIXELS = 1e9
def mmqa_display(question, target_size=512):
question = {k.lower(): v for k, v in question.items()}
keys = list(question.keys())
keys = [k for k in keys if k not in ['index', 'image']]
images = question['image']
if isinstance(images, str):
images = [images]
idx = question.pop('index', 'XXX')
print(f'INDEX: {idx}')
for im in images:
image = decode_base64_to_image(im, target_size=target_size)
display(image) # noqa: F821
for k in keys:
try:
if not pd.isna(question[k]):
print(f'{k.upper()}. {question[k]}')
except ValueError:
if False in pd.isna(question[k]):
print(f'{k.upper()}. {question[k]}')
def encode_image_to_base64(img, target_size=-1):
# if target_size == -1, do not resize
# else, cap the maximum size to (target_size, target_size)
if img.mode in ('RGBA', 'P'):
img = img.convert('RGB')
tmp = osp.join('/tmp', str(uuid4()) + '.jpg')
if target_size > 0:
img.thumbnail((target_size, target_size))
img.save(tmp)
with open(tmp, 'rb') as image_file:
image_data = image_file.read()
ret = base64.b64encode(image_data).decode('utf-8')
os.remove(tmp)
return ret
def encode_image_file_to_base64(image_path, target_size=-1):
image = Image.open(image_path)
return encode_image_to_base64(image, target_size=target_size)
def decode_base64_to_image(base64_string, target_size=-1):
image_data = base64.b64decode(base64_string)
image = Image.open(io.BytesIO(image_data))
if image.mode in ('RGBA', 'P'):
image = image.convert('RGB')
if target_size > 0:
image.thumbnail((target_size, target_size))
return image
def decode_base64_to_image_file(base64_string, image_path, target_size=-1):
image = decode_base64_to_image(base64_string, target_size=target_size)
image.save(image_path)
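# Illustrative usage sketch (the paths below are placeholders, not shipped assets):
# round-trip an image through the helpers above, resizing on both encode and decode.
def _demo_base64_roundtrip(src_path='/tmp/example.jpg', dst_path='/tmp/example_copy.jpg'):
    b64 = encode_image_file_to_base64(src_path, target_size=512)  # <=512px JPEG as a base64 str
    decode_base64_to_image_file(b64, dst_path, target_size=256)  # decode and save a smaller copy
    return b64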
def build_option_str(option_dict):
s = 'There are several options: \n'
for c, content in option_dict.items():
if not pd.isna(content):
s += f'{c}. {content}\n'
return s
def isimg(s):
return osp.exists(s) or s.startswith('http')
def read_ok(img_path):
if not osp.exists(img_path):
return False
try:
im = Image.open(img_path)
assert im.size[0] > 0 and im.size[1] > 0
return True
except Exception:
return False
def gpt_key_set():
openai_key = os.environ.get('OPENAI_API_KEY', None)
return isinstance(openai_key, str) and openai_key.startswith('sk-')
def apiok(wrapper):
s = wrapper.generate('Hello!')
return wrapper.fail_msg not in s
def circular_pred(df, extract_func=None):
if extract_func is None:
extract_func = lambda x: x # noqa: E731
df = df.sort_values('index')
from vlmeval.utils import can_infer_option
shift = int(1e6)
choices = [extract_func(x) for x in df['prediction']]
pred_map = {i: c for i, c in zip(df['index'], choices)}
flag_map = {i: True for i in pred_map if i < 1e6}
valid_map = {i: True for i in pred_map if i < 1e6}
for i in df['index']:
if i >= shift and pred_map[i] and pred_map[i - shift]:
if (
pred_map[i] not in list(string.ascii_uppercase) or # noqa: W504
pred_map[i - shift] not in list(string.ascii_uppercase)
):
valid_map[i % shift] = False
continue
if (ord(pred_map[i]) - ord(pred_map[i - shift])) % 4 == 1:
continue
else:
flag_map[i % shift] = False
flag_map = {k: v for k, v in flag_map.items() if valid_map[k]}
flags = list(flag_map.values())
return np.mean(flags)
from .matching_util import can_infer, can_infer_option, can_infer_text
from .mp_util import track_progress_rich
from .custom_prompt import CustomPrompt
from .dataset_config import dataset_URLs, img_root_map, DATASET_TYPE, abbr2full
from .dataset import TSVDataset, split_MMMU, MMMU_result_transfer
__all__ = [
'can_infer', 'can_infer_option', 'can_infer_text', 'track_progress_rich',
'TSVDataset', 'dataset_URLs', 'img_root_map', 'DATASET_TYPE', 'CustomPrompt',
'split_MMMU', 'abbr2full'
]
from ..smp import *
from .dataset_config import img_root_map
from abc import abstractmethod
class CustomPrompt:
@abstractmethod
def use_custom_prompt(self, dataset):
raise NotImplementedError
@abstractmethod
def build_prompt(self, line, dataset):
raise NotImplementedError
def dump_image(self, line, dataset):
ROOT = LMUDataRoot()
assert isinstance(dataset, str)
img_root = osp.join(ROOT, 'images', img_root_map[dataset] if dataset in img_root_map else dataset)
os.makedirs(img_root, exist_ok=True)
if isinstance(line['image'], list):
tgt_path = []
assert 'image_path' in line
for img, im_name in zip(line['image'], line['image_path']):
path = osp.join(img_root, im_name)
if not read_ok(path):
decode_base64_to_image_file(img, path)
tgt_path.append(path)
else:
tgt_path = osp.join(img_root, f"{line['index']}.jpg")
if not read_ok(tgt_path):
decode_base64_to_image_file(line['image'], tgt_path)
return tgt_path
import pandas as pd
import hashlib
from ..smp import *
from .dataset_config import dataset_URLs, dataset_md5_dict, DATASET_TYPE
from .custom_prompt import CustomPrompt
from .matching_util import can_infer
def isliststr(s):
return (s[0] == '[') and (s[-1] == ']')
def check_md5(data_path, dataset):
if dataset not in dataset_md5_dict:
warnings.warn(f'We do not have an md5 record for dataset {dataset}, skipping the md5 check. ')
return True
assert osp.exists(data_path)
with open(data_path, 'rb') as f:
hash = hashlib.new('md5')
for chunk in iter(lambda: f.read(2**20), b''):
hash.update(chunk)
if str(hash.hexdigest()) == dataset_md5_dict[dataset]:
return True
else:
warnings.warn('The data file is incomplete, so it needs to be downloaded again.')
return False
def split_MMMU(msgs):
text, images = None, []
for s in msgs:
if s['type'] == 'image':
images.append(s['value'])
elif s['type'] == 'text':
assert text is None
text = s['value']
text_segs = text.split('<image ')
segs = [dict(type='text', value=text_segs[0])]
for i, seg in enumerate(text_segs):
if i == 0:
continue
assert istype(seg[0], int) and seg[1] == '>'
image_idx = int(seg[0]) - 1
segs.append(dict(type='image', value=images[image_idx]))
segs.append(dict(type='text', value=seg[2:]))
return segs
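# Illustrative example of split_MMMU (placeholder values): given
#   msgs = [dict(type='image', value='img1.jpg'),
#           dict(type='text', value='Compare <image 1> with the chart.')]
# it returns the interleaved segments
#   [dict(type='text', value='Compare '),
#    dict(type='image', value='img1.jpg'),
#    dict(type='text', value=' with the chart.')]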
def MMMU_result_transfer(result_path):
res = {}
result_data = load(result_path)
mcq = result_data['A'].notna()
lt = len(result_data)
for i in range(lt):
line = result_data.iloc[i]
if mcq[i]:
options = {
cand: line[cand]
for cand in string.ascii_uppercase
if cand in line and not pd.isna(line[cand])
}
prediction = line['prediction']
infer_prediction = can_infer(prediction, options)
res[line['id']] = infer_prediction
else:
res[line['id']] = line['prediction']
result_json = result_path.replace('.xlsx', '.json')
dump(res, result_json)
return result_json
class TSVDataset(CustomPrompt):
def __init__(self, dataset='MMBench', skip_noimg=True):
self.data_root = LMUDataRoot()
assert osp.exists(self.data_root)
self.dataset = dataset
self.dataset_type = DATASET_TYPE(dataset)
if dataset in dataset_URLs:
url = dataset_URLs[dataset]
file_name = url.split('/')[-1]
data_path = osp.join(self.data_root, file_name)
if osp.exists(data_path) and check_md5(data_path, dataset):
pass
elif osp.isfile(url):
# If url is actually a file path, use it directly
data_path = url
else:
warnings.warn('The dataset tsv is not downloaded')
download_file(url, data_path)
else:
data_path = osp.join(self.data_root, dataset + '.tsv')
assert osp.exists(data_path)
data = load(data_path)
self.skip_noimg = skip_noimg
if skip_noimg and 'image' in data:
data = data[~pd.isna(data['image'])]
# Prompt for Captioning
if listinstr(['COCO'], dataset):
data['question'] = [(
'Please describe this image in general. Directly provide the description, '
'do not include prefix like "This image depicts". '
)] * len(data)
data['index'] = [str(x) for x in data['index']]
self.meta_only = True
if 'image' in data:
data['image'] = [str(x) for x in data['image']]
image_map = {x: y for x, y in zip(data['index'], data['image'])}
for k in image_map:
if len(image_map[k]) <= 64:
idx = image_map[k]
assert idx in image_map and len(image_map[idx]) > 64
image_map[k] = image_map[idx]
data['image'] = [
eval(image_map[k]) if isliststr(image_map[k]) else image_map[k]
for k in data['index']
]
self.meta_only = False
if 'image_path' in data:
data['image_path'] = [
eval(pths) if isliststr(pths) else pths for pths in data['image_path']
]
if np.all([istype(x, int) for x in data['index']]):
data['index'] = [int(x) for x in data['index']]
self.data = data
def __len__(self):
return len(self.data)
def build_prompt(self, line, dataset=None):
if dataset is None:
dataset = self.dataset
if isinstance(line, int):
line = self.data.iloc[line]
if self.meta_only:
tgt_path = line['image_path']
else:
tgt_path = self.dump_image(line, dataset)
prompt = line['question']
if DATASET_TYPE(dataset) == 'multi-choice':
question = line['question']
options = {
cand: line[cand]
for cand in string.ascii_uppercase
if cand in line and not pd.isna(line[cand])
}
options_prompt = 'Options:\n'
for key, item in options.items():
options_prompt += f'{key}. {item}\n'
hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
prompt = ''
if hint is not None:
prompt += f'Hint: {hint}\n'
prompt += f'Question: {question}\n'
if len(options):
prompt += options_prompt
prompt += 'Please select the correct answer from the options above. \n'
elif DATASET_TYPE(dataset) == 'VQA':
if listinstr(['ocrvqa', 'textvqa', 'chartqa', 'docvqa'], dataset.lower()):
prompt += '\nPlease try to answer the question with short words or phrases if possible.\n'
msgs = []
if isinstance(tgt_path, list):
msgs.extend([dict(type='image', value=p) for p in tgt_path])
else:
msgs = [dict(type='image', value=tgt_path)]
msgs.append(dict(type='text', value=prompt))
return msgs
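# Illustrative final multi-choice prompt built above (placeholder content):
#   Hint: <hint text, if any>
#   Question: Which animal is shown in the image?
#   Options:
#   A. Cat
#   B. Dog
#   Please select the correct answer from the options above.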
def display(self, line):
if isinstance(line, int):
line = self.data.iloc[line]
mmqa_display(line)
from ..smp import listinstr
dataset_URLs = {
# MMBench v1.0
'MMBench_DEV_EN': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_DEV_EN.tsv',
'MMBench_TEST_EN': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_TEST_EN.tsv',
'MMBench_DEV_CN': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_DEV_CN.tsv',
'MMBench_TEST_CN': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_TEST_CN.tsv',
'MMBench': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench.tsv', # Internal Only
'MMBench_CN': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_CN.tsv', # Internal Only
# MMBench v1.1
'MMBench_DEV_EN_V11': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_DEV_EN_V11.tsv',
'MMBench_TEST_EN_V11': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_TEST_EN_V11.tsv',
'MMBench_DEV_CN_V11': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_DEV_CN_V11.tsv',
'MMBench_TEST_CN_V11': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_TEST_CN_V11.tsv',
'MMBench_V11': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_V11.tsv', # Internal Only
'MMBench_CN_V11': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_CN_V11.tsv', # Internal Only
# CCBench
'CCBench': 'https://opencompass.openxlab.space/utils/VLMEval/CCBench.tsv',
'MME': 'https://opencompass.openxlab.space/utils/VLMEval/MME.tsv',
'SEEDBench_IMG': 'https://opencompass.openxlab.space/utils/VLMEval/SEEDBench_IMG.tsv',
'CORE_MM': 'https://opencompass.openxlab.space/utils/VLMEval/CORE_MM.tsv',
'MMVet': 'https://opencompass.openxlab.space/utils/VLMEval/MMVet.tsv',
'COCO_VAL': 'https://opencompass.openxlab.space/utils/VLMEval/COCO_VAL.tsv',
'OCRVQA_TEST': 'https://opencompass.openxlab.space/utils/VLMEval/OCRVQA_TEST.tsv',
'OCRVQA_TESTCORE': 'https://opencompass.openxlab.space/utils/VLMEval/OCRVQA_TESTCORE.tsv',
'TextVQA_VAL': 'https://opencompass.openxlab.space/utils/VLMEval/TextVQA_VAL.tsv',
'MMMU_DEV_VAL': 'https://opencompass.openxlab.space/utils/VLMEval/MMMU_DEV_VAL.tsv',
'MMMU_TEST': 'https://opencompass.openxlab.space/utils/VLMEval/MMMU_TEST.tsv',
'MathVista_MINI': 'https://opencompass.openxlab.space/utils/VLMEval/MathVista_MINI.tsv',
'ScienceQA_VAL': 'https://opencompass.openxlab.space/utils/VLMEval/ScienceQA_VAL.tsv',
'ScienceQA_TEST': 'https://opencompass.openxlab.space/utils/VLMEval/ScienceQA_TEST.tsv',
'HallusionBench': 'https://opencompass.openxlab.space/utils/VLMEval/HallusionBench.tsv',
'DocVQA_VAL': 'https://opencompass.openxlab.space/utils/VLMEval/DocVQA_VAL.tsv',
'DocVQA_TEST': 'https://opencompass.openxlab.space/utils/VLMEval/DocVQA_TEST.tsv',
'InfoVQA_VAL': 'https://opencompass.openxlab.space/utils/VLMEval/InfoVQA_VAL.tsv',
'InfoVQA_TEST': 'https://opencompass.openxlab.space/utils/VLMEval/InfoVQA_TEST.tsv',
'AI2D_TEST': 'https://opencompass.openxlab.space/utils/VLMEval/AI2D_TEST.tsv',
'LLaVABench': 'https://opencompass.openxlab.space/utils/VLMEval/LLaVABench.tsv',
'OCRBench': 'https://opencompass.openxlab.space/utils/VLMEval/OCRBench.tsv',
'ChartQA_TEST': 'https://opencompass.openxlab.space/utils/VLMEval/ChartQA_TEST.tsv',
'MMStar': 'https://opencompass.openxlab.space/utils/VLMEval/MMStar.tsv',
'RealWorldQA': 'https://opencompass.openxlab.space/utils/VLMEval/RealWorldQA.tsv',
'POPE': 'https://opencompass.openxlab.space/utils/VLMEval/POPE.tsv',
}
dataset_md5_dict = {
# MMBench v1.0
'MMBench_DEV_EN': 'b6caf1133a01c6bb705cf753bb527ed8',
'MMBench_TEST_EN': '6939fadb0ce626fefc0bdc9c64efc528',
'MMBench_DEV_CN': '08b8fc3324a5ed74155350f57be69fbd',
'MMBench_TEST_CN': '7e1239baf0ee4c8b513e19705a0f317e',
'MMBench': '4115aea3383f3dd0083be6a633e0f820', # Internal Only
'MMBench_CN': '2e053ffc90ea598b1feae13c36dc13ee', # Internal Only
# MMBench v1.1
'MMBench_DEV_EN_V11': '30c05be8f2f347a50be25aa067248184',
'MMBench_TEST_EN_V11': '26f0f15381a21720255091d3e0316ce6',
'MMBench_DEV_CN_V11': '593f9b5f6bea453d870a798b34ae4f37',
'MMBench_TEST_CN_V11': '74bbe4556dac745613c7cbe5ad787050',
'MMBench_V11': 'b9276414f57af1308dcc4d0cd9b42e7c', # Internal Only
'MMBench_CN_V11': '95f6980dd1b4de38e3cbffe0305a3f25', # Internal Only
# CCBench
'CCBench': '1de88b4257e7eee3f60b18d45eda6f07',
'MME': 'b36b43c3f09801f5d368627fb92187c3',
'SEEDBench_IMG': '68017231464752261a2526d6ca3a10c0',
'CORE_MM': '8a8da2f2232e79caf98415bfdf0a202d',
'MMVet': '748aa6d4aa9d4de798306a63718455e3',
'COCO_VAL': '72a5079dead060269ac222c5aa5128af',
'OCRVQA_TEST': 'ca46a6d74b403e9d6c0b670f6fc00db9',
'OCRVQA_TESTCORE': 'c5239fe77db8bdc1f2ad8e55e0d1fe97',
'TextVQA_VAL': 'b233b31f551bbf4056f2f955da3a92cd',
'MMMU_DEV_VAL': '521afc0f3bf341e6654327792781644d',
'MMMU_TEST': 'c19875d11a2d348d07e5eb4bdf33166d',
'MathVista_MINI': 'f199b98e178e5a2a20e7048f5dcb0464',
'ScienceQA_VAL': '96320d05e142e585e7204e72affd29f3',
'ScienceQA_TEST': 'e42e9e00f9c59a80d8a5db35bc32b71f',
'HallusionBench': '0c23ac0dc9ef46832d7a24504f2a0c7c',
'DocVQA_VAL': 'd5ee77e1926ff10690d469c56b73eabf',
'DocVQA_TEST': '6a2f28cac26ef2d3447374e8c6f6c8e9',
'InfoVQA_VAL': '2342e9c225222f0ef4dec545ebb126fe',
'InfoVQA_TEST': 'df535bf51b88dc9718252c34131a6227',
'AI2D_TEST': '0f593e0d1c7df9a3d69bf1f947e71975',
'LLaVABench': 'd382a093f749a697820d3dadd61c8428',
'OCRBench': 'e953d98a987cc6e26ef717b61260b778',
'ChartQA_TEST': 'c902e0aa9be5582a7aad6dcf52734b42',
'MMStar': 'e1ecd2140806c1b1bbf54b43372efb9e',
'RealWorldQA': '92321028d2bc29040284b6674721e48f',
'POPE': 'c12f5acb142f2ef1f85a26ba2fbe41d5',
}
img_root_map = {k: k for k in dataset_URLs}
img_root_map.update({
# MMBench v1.0
'MMBench_DEV_EN': 'MMBench',
'MMBench_TEST_EN': 'MMBench',
'MMBench_DEV_CN': 'MMBench',
'MMBench_TEST_CN': 'MMBench',
'MMBench': 'MMBench', # Internal Only
'MMBench_CN': 'MMBench', # Internal Only
# MMBench v1.1
'MMBench_DEV_EN_V11': 'MMBench_V11',
'MMBench_TEST_EN_V11': 'MMBench_V11',
'MMBench_DEV_CN_V11': 'MMBench_V11',
'MMBench_TEST_CN_V11': 'MMBench_V11',
'MMBench_V11': 'MMBench_V11', # Internal Only
'MMBench_CN_V11': 'MMBench_V11', # Internal Only
'COCO_VAL': 'COCO',
'OCRVQA_TEST': 'OCRVQA',
'OCRVQA_TESTCORE': 'OCRVQA',
'TextVQA_VAL': 'TextVQA',
'MMMU_DEV_VAL': 'MMMU',
'MMMU_TEST': 'MMMU',
'MathVista_MINI': 'MathVista',
'HallusionBench': 'Hallusion',
'DocVQA_VAL': 'DocVQA',
'DocVQA_TEST': 'DocVQA_TEST',
'OCRBench': 'OCRBench',
'ChartQA_TEST': 'ChartQA_TEST',
'InfoVQA_VAL': 'InfoVQA_VAL',
'InfoVQA_TEST': 'InfoVQA_TEST',
'MMStar': 'MMStar',
'RealWorldQA': 'RealWorldQA',
'POPE': 'POPE',
})
assert set(dataset_URLs) == set(img_root_map)
def DATASET_TYPE(dataset):
# Dealing with Custom Dataset
dataset = dataset.lower()
if listinstr(['mmbench', 'seedbench', 'ccbench', 'mmmu', 'scienceqa', 'ai2d', 'mmstar', 'realworldqa'], dataset):
return 'multi-choice'
elif listinstr(['mme', 'hallusion', 'pope'], dataset):
return 'Y/N'
elif 'coco' in dataset:
return 'Caption'
elif listinstr(['ocrvqa', 'textvqa', 'chartqa', 'mathvista', 'docvqa', 'infovqa', 'llavabench',
'mmvet', 'ocrbench'], dataset):
return 'VQA'
else:
if dataset not in dataset_URLs:
import warnings
warnings.warn(f"Dataset {dataset} not found in dataset_URLs, will use 'multi-choice' as the default TYPE.")
return 'multi-choice'
else:
return 'QA'
def abbr2full(s):
datasets = [x for x in img_root_map]
ins = [s in d for d in datasets]
if sum(ins) == 1:
for d in datasets:
if s in d:
return d
else:
return s
import string
import copy as cp
import os
from ..smp import *
def can_infer_option(answer, choices):
verbose = os.environ.get('VERBOSE', 0)
# Choices is a dictionary
if 'Failed to obtain answer via API' in answer:
return False
reject_to_answer = [
"Sorry, I can't help with images of people yet.",
"I can't process this file.",
"I'm sorry, but without the image provided",
'Cannot determine the answer'
]
for err in reject_to_answer:
if err in answer:
return 'Z'
def count_choice(splits, choices, prefix='', suffix=''):
cnt = 0
for c in choices:
if prefix + c + suffix in splits:
cnt += 1
return cnt
answer_mod = cp.copy(answer)
chars = '.()[],:;!*#{}'
for c in chars:
answer_mod = answer_mod.replace(c, ' ')
splits = [x.strip() for x in answer_mod.split()]
count = count_choice(splits, choices)
if count == 1:
for ch in choices:
if 'A' in splits and len(splits) > 3 and verbose:
logger = get_logger('Evaluation')
logger.info(f'A might be a quantifier in the string: {answer}.')
return False
if ch in splits:
return ch
elif count == 0 and count_choice(splits, {'Z', ''}) == 1:
return 'Z'
return False
def can_infer_text(answer, choices):
answer = answer.lower()
assert isinstance(choices, dict)
for k in choices:
assert k in string.ascii_uppercase
choices[k] = str(choices[k]).lower()
cands = []
for k in choices:
if choices[k] in answer:
cands.append(k)
if len(cands) == 1:
return cands[0]
return False
def can_infer(answer, choices):
answer = str(answer)
copt = can_infer_option(answer, choices)
return copt if copt else can_infer_text(answer, choices)
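# Usage sketch: can_infer first tries exact option-letter matching (can_infer_option),
# then falls back to matching the option text (can_infer_text), and returns False when
# neither succeeds. Illustrative values:
#   can_infer('The answer is B. Paris', {'A': 'London', 'B': 'Paris'})               -> 'B'
#   can_infer('It is the capital of France, Paris.', {'A': 'London', 'B': 'Paris'})  -> 'B'
#   can_infer('Not sure.', {'A': 'London', 'B': 'Paris'})                            -> False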
from multiprocessing import Pool
import os
from typing import Callable, Iterable, Sized
from rich.progress import (BarColumn, MofNCompleteColumn, Progress, Task,
TaskProgressColumn, TextColumn, TimeRemainingColumn)
from rich.text import Text
import os.path as osp
import portalocker
from ..smp import load, dump
class _Worker:
"""Function wrapper for ``track_progress_rich``"""
def __init__(self, func) -> None:
self.func = func
def __call__(self, inputs):
inputs, idx = inputs
if not isinstance(inputs, (tuple, list, dict)):
inputs = (inputs, )
if isinstance(inputs, dict):
return self.func(**inputs), idx
else:
return self.func(*inputs), idx
class _SkipFirstTimeRemainingColumn(TimeRemainingColumn):
"""Skip calculating remaining time for the first few times.
Args:
skip_times (int): The number of times to skip. Defaults to 0.
"""
def __init__(self, *args, skip_times=0, **kwargs):
super().__init__(*args, **kwargs)
self.skip_times = skip_times
def render(self, task: Task) -> Text:
"""Show time remaining."""
if task.completed <= self.skip_times:
return Text('-:--:--', style='progress.remaining')
return super().render(task)
def _tasks_with_index(tasks):
"""Add index to tasks."""
for idx, task in enumerate(tasks):
yield task, idx
def track_progress_rich(func: Callable,
tasks: Iterable = tuple(),
task_num: int = None,
nproc: int = 1,
chunksize: int = 1,
description: str = 'Processing',
save=None, keys=None,
color: str = 'blue') -> list:
"""Track the progress of parallel task execution with a progress bar. The
built-in :mod:`multiprocessing` module is used for process pools and tasks
are done with :func:`Pool.map` or :func:`Pool.imap_unordered`.
Args:
func (callable): The function to be applied to each task.
tasks (Iterable or Sized): A tuple of tasks. There are several cases
for different format tasks:
- When ``func`` accepts no arguments: tasks should be an empty
tuple, and ``task_num`` must be specified.
- When ``func`` accepts only one argument: tasks should be a tuple
containing the argument.
- When ``func`` accepts multiple arguments: tasks should be a
tuple, with each element representing a set of arguments.
If an element is a ``dict``, it will be parsed as a set of
keyword-only arguments.
Defaults to an empty tuple.
task_num (int, optional): If ``tasks`` is an iterator which does not
have length, the number of tasks can be provided by ``task_num``.
Defaults to None.
nproc (int): Number of worker processes; if nproc is 1,
a single process is used. Defaults to 1.
chunksize (int): Refer to :class:`multiprocessing.Pool` for details.
Defaults to 1.
description (str): The description of the progress bar.
Defaults to "Processing".
save (str, optional): If given, the path of a results file that is updated
after each finished task, keyed by ``keys``, so partial results persist.
Defaults to None.
keys (list, optional): The keys under which results are stored in ``save``;
must have the same length as ``tasks``. Defaults to None.
color (str): The color of the progress bar. Defaults to "blue".
Examples:
>>> import time
>>> def func(x):
... time.sleep(1)
... return x**2
>>> track_progress_rich(func, range(10), nproc=2)
Returns:
list: The task results.
"""
if save is not None:
assert osp.exists(osp.dirname(save)) or osp.dirname(save) == ''
if not osp.exists(save):
dump({}, save)
if keys is not None:
assert len(keys) == len(tasks)
if not callable(func):
raise TypeError('func must be a callable object')
if not isinstance(tasks, Iterable):
raise TypeError(
f'tasks must be an iterable object, but got {type(tasks)}')
if isinstance(tasks, Sized):
if len(tasks) == 0:
if task_num is None:
raise ValueError('If tasks is an empty iterable, '
'task_num must be set')
else:
tasks = tuple(tuple() for _ in range(task_num))
else:
if task_num is not None and task_num != len(tasks):
raise ValueError('task_num does not match the length of tasks')
task_num = len(tasks)
if nproc <= 0:
raise ValueError('nproc must be a positive number')
skip_times = nproc * chunksize if nproc > 1 else 0
prog_bar = Progress(
TextColumn('{task.description}'),
BarColumn(),
_SkipFirstTimeRemainingColumn(skip_times=skip_times),
MofNCompleteColumn(),
TaskProgressColumn(show_speed=True),
)
worker = _Worker(func)
task_id = prog_bar.add_task(
total=task_num, color=color, description=description)
tasks = _tasks_with_index(tasks)
# Use single process when nproc is 1, else use multiprocess.
with prog_bar:
if nproc == 1:
results = []
for task in tasks:
result, idx = worker(task)
results.append(result)
if save is not None:
with portalocker.Lock(save, timeout=5) as fh:
ans = load(save)
ans[keys[idx]] = result
if os.environ.get('VERBOSE', False):
print(keys[idx], result, flush=True)
dump(ans, save)
fh.flush()
os.fsync(fh.fileno())
prog_bar.update(task_id, advance=1, refresh=True)
else:
with Pool(nproc) as pool:
results = []
unordered_results = []
gen = pool.imap_unordered(worker, tasks, chunksize)
try:
for result in gen:
result, idx = result
unordered_results.append((result, idx))
if save is not None:
with portalocker.Lock(save, timeout=5) as fh:
ans = load(save)
ans[keys[idx]] = result
if os.environ.get('VERBOSE', False):
print(keys[idx], result, flush=True)
dump(ans, save)
fh.flush()
os.fsync(fh.fileno())
results.append(None)
prog_bar.update(task_id, advance=1, refresh=True)
except Exception as e:
prog_bar.stop()
raise e
for result, idx in unordered_results:
results[idx] = result
return results
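# Usage sketch (hypothetical file name): cache per-task results into `save` keyed by
# `keys`, so an interrupted run can be resumed from the saved dict. With nproc > 1 the
# function must be picklable (defined at module level); the nested function below is
# only safe because nproc=1. Assumes `dump`/`load` from ..smp support the file suffix.
def _demo_track_progress(save_path='demo_results.pkl'):
    def square(x):
        return x ** 2
    tasks = [(i,) for i in range(10)]
    keys = [f'item_{i}' for i in range(10)]
    return track_progress_rich(square, tasks, nproc=1, keys=keys, save=save_path)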
import torch
torch.set_grad_enabled(False)
torch.manual_seed(1234)
from .base import BaseModel
from .minicpm_llama3_v_2_5 import MiniCPM_Llama3_V
from .minicpm_v import MiniCPM_V
from ..smp import *
from ..utils.dataset_config import img_root_map
from abc import abstractmethod
class BaseModel:
INTERLEAVE = False
allowed_types = ['text', 'image']
def use_custom_prompt(self, dataset):
"""Whether to use custom prompt for the given dataset.
Args:
dataset (str): The name of the dataset.
Returns:
bool: Whether to use custom prompt. If True, will call `build_prompt` of the VLM to build the prompt.
Default to False.
"""
return False
@abstractmethod
def build_prompt(self, line, dataset):
"""Build custom prompts for a specific dataset. Called only if `use_custom_prompt` returns True.
Args:
line (line of pd.DataFrame): The raw input line.
dataset (str): The name of the dataset.
Returns:
list[dict]: The built message, a list of type/value dicts.
"""
raise NotImplementedError
def dump_image(self, line, dataset):
"""Dump the image(s) of the input line to the corresponding dataset folder.
Args:
line (line of pd.DataFrame): The raw input line.
dataset (str): The name of the dataset.
Returns:
list[str]: The paths of the dumped images.
"""
ROOT = LMUDataRoot()
assert isinstance(dataset, str)
img_root = osp.join(ROOT, 'images', img_root_map[dataset] if dataset in img_root_map else dataset)
os.makedirs(img_root, exist_ok=True)
if isinstance(line['image'], list):
tgt_path = []
assert 'image_path' in line
for img, im_name in zip(line['image'], line['image_path']):
path = osp.join(img_root, im_name)
if not read_ok(path):
decode_base64_to_image_file(img, path)
tgt_path.append(path)
else:
tgt_path = osp.join(img_root, f"{line['index']}.jpg")
if not read_ok(tgt_path):
decode_base64_to_image_file(line['image'], tgt_path)
tgt_path = [tgt_path]
return tgt_path
@abstractmethod
def generate_inner(self, message, dataset=None):
raise NotImplementedError
def check_content(self, msgs):
"""Check the content type of the input. Four types are allowed: str, dict, liststr, listdict.
"""
if isinstance(msgs, str):
return 'str'
if isinstance(msgs, dict):
return 'dict'
if isinstance(msgs, list):
types = [self.check_content(m) for m in msgs]
if all(t == 'str' for t in types):
return 'liststr'
if all(t == 'dict' for t in types):
return 'listdict'
return 'unknown'
def preproc_content(self, inputs):
"""Convert the raw input messages to a list of dicts.
Args:
inputs: raw input messages.
Returns:
list[dict]: The preprocessed input messages. Returns None if the input cannot be preprocessed.
"""
if self.check_content(inputs) == 'str':
return [dict(type='text', value=inputs)]
elif self.check_content(inputs) == 'dict':
assert 'type' in inputs and 'value' in inputs
return [inputs]
elif self.check_content(inputs) == 'liststr':
res = []
for s in inputs:
mime, pth = parse_file(s)
if mime is None or mime == 'unknown':
res.append(dict(type='text', value=s))
else:
res.append(dict(type=mime.split('/')[0], value=pth))
return res
elif self.check_content(inputs) == 'listdict':
for item in inputs:
assert 'type' in item and 'value' in item
mime, s = parse_file(item['value'])
if mime is None:
assert item['type'] == 'text'
else:
assert mime.split('/')[0] == item['type']
item['value'] = s
return inputs
else:
return None
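# Illustrative inputs accepted by `generate` (all normalized by preproc_content into a
# list of type/value dicts; the image path is a placeholder and its mime type is decided
# by `parse_file` from ..smp):
#   'Describe the image'
#       -> [dict(type='text', value='Describe the image')]
#   ['/path/to/cat.jpg', 'What animal is this?']
#       -> [dict(type='image', value='/path/to/cat.jpg'),
#           dict(type='text', value='What animal is this?')]
#   [dict(type='image', value='/path/to/cat.jpg'), dict(type='text', value='What animal is this?')]
#       -> returned as-is after type/value checks.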
def generate(self, message, dataset=None):
"""Generate the output message.
Args:
message (list[dict]): The input message.
dataset (str, optional): The name of the dataset. Defaults to None.
Returns:
str: The generated message.
"""
assert self.check_content(message) in ['str', 'dict', 'liststr', 'listdict'], f'Invalid input type: {message}'
message = self.preproc_content(message)
assert message is not None and self.check_content(message) == 'listdict'
for item in message:
assert item['type'] in self.allowed_types, f'Invalid input type: {item["type"]}'
return self.generate_inner(message, dataset)
def message_to_promptimg(self, message):
assert not self.INTERLEAVE
model_name = self.__class__.__name__
warnings.warn(
f'Model {model_name} does not support interleaved input. '
'Will use the first image and aggregated texts as prompt. ')
num_images = len([x for x in message if x['type'] == 'image'])
if num_images == 0:
prompt = '\n'.join([x['value'] for x in message if x['type'] == 'text'])
image = None
else:
prompt = '\n'.join([x['value'] for x in message if x['type'] == 'text'])
image = [x['value'] for x in message if x['type'] == 'image'][0]
return prompt, image
import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer
from ..smp import *
from ..utils import DATASET_TYPE
from .base import BaseModel
class MiniCPM_Llama3_V(BaseModel):
INSTALL_REQ = False
INTERLEAVE = True
def __init__(self, model_path='openbmb/MiniCPM-V', **kwargs):
assert model_path is not None
self.model_path = model_path
self.ckpt = model_path
print(f'load from {self.model_path}')
self.model = AutoModel.from_pretrained(self.model_path, trust_remote_code=True)
if '.pt' in model_path:
print(f'load from {model_path}')
self.state_dict = torch.load(self.ckpt, map_location='cpu')
self.model.load_state_dict(self.state_dict, strict=False)
self.model = self.model.to(dtype=torch.float16)
self.model.eval().cuda()
self.kwargs = kwargs
self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
torch.cuda.empty_cache()
self.num_beams = 1 if self.model_path == 'openbmb/MiniCPM-V' else 3
self.options_system_prompt = ('Carefully read the following question and select the letter corresponding '
'to the correct answer. Highlight the applicable choices without giving '
'explanations.')
self.wo_options_system_prompt = 'Carefully read the following question. Answer the question directly.'
self.detail_system_prompt = 'Answer this question in detail.'
self.vqa_prompt = 'Answer the question using a single word or phrase.'
def use_custom_prompt(self, dataset):
if listinstr(['multi-choice', 'VQA'], DATASET_TYPE(dataset)):
return True
elif dataset is not None and listinstr(['HallusionBench'], dataset):
return True
return False
def build_prompt(self, line, dataset=None):
if dataset is None:
dataset = self.dataset
if isinstance(line, int):
line = self.data.iloc[line]
tgt_path = self.dump_image(line, dataset)
system_prompt = ''
question = line['question']
if DATASET_TYPE(dataset) == 'multi-choice':
options = {
cand: line[cand]
for cand in string.ascii_uppercase
if cand in line and not pd.isna(line[cand])
}
options_prompt = 'Options:\n'
for key, item in options.items():
options_prompt += f'{key}. {item}\n'
hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
prompt = ''
if hint is not None:
prompt += f'Hint: {hint}\n'
prompt += f'Question: {question}\n'
if len(options):
prompt += options_prompt
system_prompt = self.options_system_prompt + "\nPlease just indicate your choice."
else:
system_prompt = self.wo_options_system_prompt
if 'MMMU' in dataset: # Corner Case
prompt = system_prompt + '\n' + prompt
system_prompt = ''
elif dataset is not None and listinstr(['HallusionBench'], dataset):
question = line['question'] + " Yes or No?"
prompt = question
elif dataset is not None and listinstr(['OCRBench'], dataset):
system_prompt = self.vqa_prompt
question = line['question']
prompt = question
elif DATASET_TYPE(dataset) == 'VQA':
if listinstr(['LLaVABench'], dataset):
system_prompt = ""
prompt = question
elif listinstr(['MMVet'], dataset):
system_prompt = self.detail_system_prompt
prompt = question
else:
system_prompt = self.vqa_prompt
prompt = question
msgs = []
if system_prompt:
msgs.append(dict(type='text', value=system_prompt))
if isinstance(tgt_path, list):
msgs.extend([dict(type='image', value=p) for p in tgt_path])
else:
msgs = [dict(type='image', value=tgt_path)]
msgs.append(dict(type='text', value=prompt))
return msgs
def generate_inner(self, message, dataset=None):
if DATASET_TYPE(dataset) == 'multi-choice':
max_new_tokens = 200
elif DATASET_TYPE(dataset) == 'Y/N':
max_new_tokens = 3
else:
max_new_tokens = 1024
default_kwargs = dict(
max_new_tokens=max_new_tokens,
sampling=False,
num_beams=self.num_beams,
)
default_kwargs.update(self.kwargs)
content = []
# message = [
# {'type': 'text', 'value': 'sys prompt'},
# {'type': 'image', 'value': '/path/to/image1.jpg'},
# {'type': 'text', 'value': 'Here is an image:'},
# ]
for x in message:
if x['type'] == 'text':
content.append(x['value'])
elif x['type'] == 'image':
image = Image.open(x['value']).convert('RGB')
content.append(image)
msgs = [{'role': 'user', 'content': content}]
res = self.model.chat(
image=None,
msgs=msgs,
context=None,
tokenizer=self.tokenizer,
**default_kwargs
)
if isinstance(res, tuple) and len(res) > 0:
res = res[0]
# print(f"content: {content}, res: {res}")
return res
import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer
from .base import BaseModel
from ..smp import *
from ..utils import DATASET_TYPE
class MiniCPM_V(BaseModel):
INSTALL_REQ = False
INTERLEAVE = False
def __init__(self, model_path='openbmb/MiniCPM-V', **kwargs):
assert model_path is not None
self.model_path = model_path
print(f'load from {self.model_path}')
self.model = AutoModel.from_pretrained(self.model_path, trust_remote_code=True)
self.model = self.model.to(dtype=torch.bfloat16)
self.model.eval().cuda()
self.kwargs = kwargs
self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
torch.cuda.empty_cache()
self.num_beams = 1 if self.model_path == 'openbmb/MiniCPM-V' else 3
def use_custom_prompt(self, dataset):
assert dataset is not None
if listinstr(['MMMU'], dataset):
return True
return False
def build_prompt(self, line, dataset=None):
assert dataset is None or isinstance(dataset, str)
assert self.use_custom_prompt(dataset)
tgt_path = self.dump_image(line, dataset)
question = line['question']
options = {
cand: line[cand]
for cand in string.ascii_uppercase
if cand in line and not pd.isna(line[cand])
}
options_prompt = 'Options:\n'
for key, item in options.items():
options_prompt += f'{key}. {item}\n'
hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
prompt = ''
if hint is not None:
prompt += f'Hint: {hint}\n'
prompt += f'{question}\n'
if len(options):
prompt += options_prompt
prompt = 'Study the image carefully and pick the option associated with the correct answer. \
Focus solely on selecting the option and avoid including any other content.\n' + prompt
message = [dict(type='text', value=prompt)]
message.extend([dict(type='image', value=p) for p in tgt_path])
return message
def generate_inner(self, message, dataset=None):
prompt, image_path = self.message_to_promptimg(message)
image = Image.open(image_path).convert('RGB')
msgs = [{'role': 'user', 'content': prompt}]
if DATASET_TYPE(dataset) == 'multi-choice':
max_new_tokens = 20
elif DATASET_TYPE(dataset) == 'Y/N':
max_new_tokens = 100
else:
max_new_tokens = 1024
default_kwargs = dict(
max_new_tokens=max_new_tokens,
sampling=False,
num_beams=self.num_beams
)
default_kwargs.update(self.kwargs)
res, _, _ = self.model.chat(
image=image,
msgs=msgs,
context=None,
tokenizer=self.tokenizer,
**default_kwargs
)
return res
# vqa-eval
Contains the vqa_eval kit from the server.
import json
import os
import re
from torch.utils.data import Dataset
def prompt_processor(prompt):
if prompt.startswith('OCR tokens: '):
pattern = r"Question: (.*?) Short answer:"
match = re.search(pattern, prompt, re.DOTALL)
question = match.group(1)
elif 'Reference OCR token: ' in prompt and len(prompt.split('\n')) == 3:
if prompt.startswith('Reference OCR token:'):
question = prompt.split('\n')[1]
else:
question = prompt.split('\n')[0]
elif len(prompt.split('\n')) == 2:
question = prompt.split('\n')[0]
else:
raise ValueError(f'Unexpected prompt format: {prompt}')
return question.lower()
class textVQADataset(Dataset):
def __init__(
self,
image_dir="./downloads/TextVQA/train_images",
ann_path="./downloads/TextVQA/TextVQA_0.5.1_val.json",
):
self.data = json.load(open(ann_path, "r"))["data"]
self.image_dir = image_dir
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
question = self.data[idx]['question']
answers = self.data[idx]['answers']
img_id = self.data[idx]['image_id']
qid = self.data[idx]['question_id']
img_path = os.path.join(self.image_dir, f"{img_id}.jpg")
item = {
"question_id": qid,
"image_path": img_path,
"question": question,
"gt_answers": answers
}
return item
class docVQADataset(Dataset):
def __init__(
self,
image_dir= "./downloads/DocVQA/spdocvqa_images",
ann_path= "./downloads/DocVQA/val_v1.0_withQT.json",
ocr_token_path=None
):
self.data = json.load(open(ann_path, "r"))["data"]
self.image_dir = image_dir
self.ann_path = ann_path
if ocr_token_path:
self.ocr_token_data = {item['image_id']: item for item in json.load(open(ocr_token_path, "r"))["data"]}
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
question_id = self.data[idx]['questionId']
relative_img_path = self.data[idx]['image']
corrected_relative_img_path = relative_img_path.replace("documents", "images")
img_path = os.path.join(self.image_dir, corrected_relative_img_path)
question = self.data[idx]['question']
answers = self.data[idx]['answers']
question_type = self.data[idx]['question_types']
return {
"question_id": question_id,
"image_path": img_path,
"question": question,
"gt_answers": answers,
'question_type': question_type,
}
class docVQATESTDataset(Dataset):
def __init__(
self,
image_dir= "./downloads/DocVQA/spdocvqa_images",
ann_path= "./downloads/DocVQA/test_v1.0.json",
ocr_token_path=None
):
self.data = json.load(open(ann_path, "r"))["data"]
self.image_dir = image_dir
self.ann_path = ann_path
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
question_id = self.data[idx]['questionId']
relative_img_path = self.data[idx]['image']
corrected_relative_img_path = relative_img_path.replace("documents", "images")
img_path = os.path.join(self.image_dir, corrected_relative_img_path)
question = self.data[idx]['question']
return {
"question_id": question_id,
"image_path": img_path,
"question": question,
"gt_answers": "",
'question_type': "",
}
import sys
import datetime
import json
import os
import torch
script_dir = os.path.dirname(os.path.realpath(__file__))
sys.path.append(os.path.join(script_dir, '..'))
from datasets.vqa_dataset import docVQADataset, docVQATESTDataset, textVQADataset
print(torch.__version__)
import numpy as np
from eval_utils.getargs import parse_args
from eval_utils.vqa_evaluate import *
def get_model(args):
if args.model_name == '':
raise ValueError('Model name cannot be an empty string!')
from models.MiniCPM.minicpmv import MiniCPM_V
model_path = args.model_path
ckpt = args.ckpt
model = MiniCPM_V(model_path=model_path, ckpt=ckpt, device=args.device)
return model
def main(args):
np.random.seed(0)
max_sample_num = None
torch.distributed.init_process_group(
backend='nccl',
world_size=int(os.getenv('WORLD_SIZE', '1')),
rank=int(os.getenv('RANK', '0')),
)
torch.cuda.set_device(int(os.getenv('LOCAL_RANK', 0)))
print(f'Init Rank-{torch.distributed.get_rank()}')
if torch.distributed.is_initialized():
args.device = torch.device(f"cuda:{torch.cuda.current_device()}")
model = get_model(args)
result = {}
time = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
if args.eval_textVQA or args.eval_all:
dataset = textVQADataset(args.textVQA_image_dir, args.textVQA_ann_path)
if max_sample_num is not None:
dataset = torch.utils.data.Subset(dataset, range(max_sample_num))
acc = evaluate_VQA(model, dataset, args.model_name, 'textVQA', time, \
batch_size=args.batchsize, generate_method=args.generate_method, answer_path=args.answer_path)
result['textVQA'] = acc
if args.eval_docVQA or args.eval_all:
dataset = docVQADataset(args.docVQA_image_dir, args.docVQA_ann_path)
if max_sample_num is not None:
dataset = torch.utils.data.Subset(dataset, range(max_sample_num))
acc = evaluate_VQA(model, dataset, args.model_name, 'docVQA', time, batch_size=args.batchsize, generate_method=args.generate_method, answer_path=args.answer_path)
result['docVQA'] = acc
if args.eval_docVQATest or args.eval_all:
target_dataset = "docVQATest"
dataset = docVQATESTDataset(args.docVQATest_image_dir, args.docVQATest_ann_path)
if max_sample_num is not None:
dataset = torch.utils.data.Subset(dataset, range(max_sample_num))
acc = evaluate_VQA(model, dataset, args.model_name, target_dataset, time, batch_size=args.batchsize, generate_method=args.generate_method, answer_path=args.answer_path)
result['docVQATest'] = acc
if torch.distributed.is_initialized():
torch.distributed.barrier()
if torch.distributed.is_initialized() and torch.distributed.get_rank() != 0:
return None
result_path = os.path.join(os.path.join(args.answer_path, args.model_name), 'result.json')
output_flag = False
for k, v in result.items():
if v > 0.0:
output_flag = True
break
if output_flag:
with open(result_path, "w") as f:
f.write(json.dumps(result, indent=4))
if __name__ == "__main__":
args = parse_args()
main(args)
import json
import glob
import re
def has_word(sentence, word):
pattern = r"\b" + re.escape(word) + r"\b"
match = re.search(pattern, sentence)
if match:
return True
else:
return False
def remove_special_chars(s):
pattern = r"[^a-zA-Z0-9\s]"
s = re.sub(pattern, "", s)
return s
for model in glob.glob('./answer_save/*'):
print(model, ':')
result_list = sorted(glob.glob(f'{model}/*.json'))
for task_result_path in result_list:
taskname = task_result_path.split('/')[-1]
taskname = taskname.split('.')[0]
if taskname not in ['IIIT5K', 'svt', 'IC13_857', 'IC15_1811', 'svtp', 'ct80',
'cocotext', 'ctw', 'totaltext', 'HOST']:
continue
correct = 0
num = 0
with open(task_result_path, 'r') as f:
records = json.load(f)[:100]
for i in range(len(records)):
gt_answers = records[i]['gt_answers']
answer = records[i]['answer']
gt_answers = remove_special_chars(gt_answers).lower()
answer = remove_special_chars(answer).lower()
if has_word(answer, gt_answers):
correct += 1
num += 1
print(f'{taskname:10s}:{float(correct)/num*100:.2f}')
print('=' * 32)
import argparse
def parse_args():
parser = argparse.ArgumentParser(description="Demo")
parser.add_argument('--local-rank', type=int, default=0, help='Local rank for distributed training')
# textVQA
parser.add_argument("--textVQA_image_dir", type=str, default="")
parser.add_argument("--textVQA_ann_path", type=str, default="")
# docVQA
parser.add_argument("--docVQA_image_dir", type=str, default="")
parser.add_argument("--docVQA_ann_path", type=str, default="")
# docVQATest
parser.add_argument("--docVQATest_image_dir", type=str, default="")
parser.add_argument("--docVQATest_ann_path", type=str, default="")
# result path
parser.add_argument("--answer_path", type=str, default="./answers-new")
# eval
parser.add_argument(
"--eval_textVQA",
action="store_true",
default=False,
help="Whether to evaluate on textVQA."
)
parser.add_argument(
"--eval_docVQA",
action="store_true",
default=False,
help="Whether to evaluate on docVQA."
)
parser.add_argument(
"--eval_docVQATest",
action="store_true",
default=False,
help="Whether to evaluate on docVQA."
)
parser.add_argument(
"--eval_all",
action="store_true",
default=False,
help="Whether to evaluate all datasets"
)
parser.add_argument("--model_name", type=str, default="")
parser.add_argument("--model_path", type=str, default="")
parser.add_argument("--generate_method", type=str, default="", help="generate with interleave or not.")
parser.add_argument("--device", type=str, default="cuda:0")
parser.add_argument('--batchsize', type=int, default=1, help='Batch size for processing.')
parser.add_argument("--ckpt", type=str, default="")
args = parser.parse_args()
return args
import itertools
import json
import os
import re
from collections import namedtuple
import torch
from tqdm import tqdm
class InferenceSampler(torch.utils.data.sampler.Sampler):
def __init__(self, size):
self._size = int(size)
assert size > 0
self._rank = torch.distributed.get_rank()
self._world_size = torch.distributed.get_world_size()
self._local_indices = self._get_local_indices(size, self._world_size,
self._rank)
@staticmethod
def _get_local_indices(total_size, world_size, rank):
shard_size = total_size // world_size
left = total_size % world_size
shard_sizes = [shard_size + int(r < left) for r in range(world_size)]
begin = sum(shard_sizes[:rank])
end = min(sum(shard_sizes[:rank + 1]), total_size)
return range(begin, end)
def __iter__(self):
yield from self._local_indices
def __len__(self):
return len(self._local_indices)
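# Worked example of _get_local_indices: with total_size=10 and world_size=3 the shard
# sizes are [4, 3, 3], so rank 0 iterates indices 0-3, rank 1 indices 4-6, and rank 2
# indices 7-9, covering every sample exactly once across ranks.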
def collate_fn_vqa(batches):
'''Collate a batch of VQA samples into parallel lists of fields.'''
image_paths = [_['image_path'] for _ in batches]
questions = [_['question'] for _ in batches]
gt_answers = [_['gt_answers'] for _ in batches]
ocr_tokens = [_['ocr_tokens'] if 'ocr_tokens' in _ else None for _ in batches]
question_ids = [_['question_id'] if 'question_id' in _ else None for _ in batches]
question_type = [_['question_type'] if 'question_type' in _ else None for _ in batches]
return image_paths, questions, gt_answers, ocr_tokens, question_ids, question_type
def has_word(sentence, word):
if word[0].isalnum():
start_pattern = r"\b"
else:
start_pattern = r""
if word[-1].isalnum():
end_pattern = r"\b"
else:
end_pattern = r""
pattern = start_pattern + re.escape(word) + end_pattern
match = re.search(pattern, sentence)
return bool(match)
def remove_special_chars(s):
pattern = r"[^a-zA-Z0-9\s]"
s = re.sub(pattern, "", s)
return s
def levenshtein_distance(s1, s2):
if len(s1) > len(s2):
s1, s2 = s2, s1
distances = range(len(s1) + 1)
for i2, c2 in enumerate(s2):
distances_ = [i2+1]
for i1, c1 in enumerate(s1):
if c1 == c2:
distances_.append(distances[i1])
else:
distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1])))
distances = distances_
return distances[-1]
class VQAEval:
def __init__(self):
self.contractions = {
"aint": "ain't",
"arent": "aren't",
"cant": "can't",
"couldve": "could've",
"couldnt": "couldn't",
"couldn'tve": "couldn't've",
"couldnt've": "couldn't've",
"didnt": "didn't",
"doesnt": "doesn't",
"dont": "don't",
"hadnt": "hadn't",
"hadnt've": "hadn't've",
"hadn'tve": "hadn't've",
"hasnt": "hasn't",
"havent": "haven't",
"hed": "he'd",
"hed've": "he'd've",
"he'dve": "he'd've",
"hes": "he's",
"howd": "how'd",
"howll": "how'll",
"hows": "how's",
"Id've": "I'd've",
"I'dve": "I'd've",
"Im": "I'm",
"Ive": "I've",
"isnt": "isn't",
"itd": "it'd",
"itd've": "it'd've",
"it'dve": "it'd've",
"itll": "it'll",
"let's": "let's",
"maam": "ma'am",
"mightnt": "mightn't",
"mightnt've": "mightn't've",
"mightn'tve": "mightn't've",
"mightve": "might've",
"mustnt": "mustn't",
"mustve": "must've",
"neednt": "needn't",
"notve": "not've",
"oclock": "o'clock",
"oughtnt": "oughtn't",
"ow's'at": "'ow's'at",
"'ows'at": "'ow's'at",
"'ow'sat": "'ow's'at",
"shant": "shan't",
"shed've": "she'd've",
"she'dve": "she'd've",
"she's": "she's",
"shouldve": "should've",
"shouldnt": "shouldn't",
"shouldnt've": "shouldn't've",
"shouldn'tve": "shouldn't've",
"somebody'd": "somebodyd",
"somebodyd've": "somebody'd've",
"somebody'dve": "somebody'd've",
"somebodyll": "somebody'll",
"somebodys": "somebody's",
"someoned": "someone'd",
"someoned've": "someone'd've",
"someone'dve": "someone'd've",
"someonell": "someone'll",
"someones": "someone's",
"somethingd": "something'd",
"somethingd've": "something'd've",
"something'dve": "something'd've",
"somethingll": "something'll",
"thats": "that's",
"thered": "there'd",
"thered've": "there'd've",
"there'dve": "there'd've",
"therere": "there're",
"theres": "there's",
"theyd": "they'd",
"theyd've": "they'd've",
"they'dve": "they'd've",
"theyll": "they'll",
"theyre": "they're",
"theyve": "they've",
"twas": "'twas",
"wasnt": "wasn't",
"wed've": "we'd've",
"we'dve": "we'd've",
"weve": "we've",
"werent": "weren't",
"whatll": "what'll",
"whatre": "what're",
"whats": "what's",
"whatve": "what've",
"whens": "when's",
"whered": "where'd",
"wheres": "where's",
"whereve": "where've",
"whod": "who'd",
"whod've": "who'd've",
"who'dve": "who'd've",
"wholl": "who'll",
"whos": "who's",
"whove": "who've",
"whyll": "why'll",
"whyre": "why're",
"whys": "why's",
"wont": "won't",
"wouldve": "would've",
"wouldnt": "wouldn't",
"wouldnt've": "wouldn't've",
"wouldn'tve": "wouldn't've",
"yall": "y'all",
"yall'll": "y'all'll",
"y'allll": "y'all'll",
"yall'd've": "y'all'd've",
"y'alld've": "y'all'd've",
"y'all'dve": "y'all'd've",
"youd": "you'd",
"youd've": "you'd've",
"you'dve": "you'd've",
"youll": "you'll",
"youre": "you're",
"youve": "you've",
}
self.manualMap = {
"none": "0",
"zero": "0",
"one": "1",
"two": "2",
"three": "3",
"four": "4",
"five": "5",
"six": "6",
"seven": "7",
"eight": "8",
"nine": "9",
"ten": "10",
}
self.articles = ["a", "an", "the"]
self.periodStrip = re.compile(r"(?!<=\d)(\.)(?!\d)")
self.commaStrip = re.compile(r"(\d)(\,)(\d)")
self.punct = [
";",
r"/",
"[",
"]",
'"',
"{",
"}",
"(",
")",
"=",
"+",
"\\",
"_",
"-",
">",
"<",
"@",
"`",
",",
"?",
"!",
]
def clean_text(self, text):
text = text.replace("\n", " ").replace("\t", " ").strip()
text = self.processPunctuation(text)
text = self.processDigitArticle(text)
return text
def evaluate_vqa_human(self, answer, gt_answers):
'''TextVQA, VQAv2, OKVQA, vizwiz'''
answer = answer.replace("\n", " ").replace("\t", " ").strip()
answer = self.processPunctuation(answer)
answer = self.processDigitArticle(answer)
gt_answers = [self.processPunctuation(ans) for ans in gt_answers]
gt_answers = [self.processDigitArticle(ans) for ans in gt_answers]
gtAcc = []
for idx, gtAnsDatum in enumerate(gt_answers):
otherGTAns = gt_answers[:idx] + gt_answers[idx+1:]
matchingAns = [item for item in otherGTAns if answer == item]
acc = min(1, float(len(matchingAns)) / 3)
gtAcc.append(acc)
avgGTAcc = float(sum(gtAcc)) / len(gtAcc) if gtAcc else 0
return avgGTAcc
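# Worked example of the VQAv2-style accuracy above: with 10 ground-truth answers and a
# prediction matching 4 of them, every leave-one-out pass still finds at least 3 matches,
# so each per-pass accuracy is min(1, matches / 3) = 1.0 and the average is 1.0. With only
# 2 matching answers the passes score 1/3 (a match held out) or 2/3 (a non-match held out),
# averaging (2 * (1/3) + 8 * (2/3)) / 10 = 0.6.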
def evaluate_anls(self, answer, gt_answers, threshold=0.5):
'''DocVQA, InfographicsVQA, STVQA'''
answer = ' '.join(answer.strip().lower().split())
if not isinstance(gt_answers, list):
gt_answers = [gt_answers]
gt_answers = [' '.join(gt_answer.strip().lower().split()) for gt_answer in gt_answers]
values = []
for gt_answer in gt_answers:
dist = levenshtein_distance(answer, gt_answer)
length = max(len(answer), len(gt_answer))
values.append(0.0 if length == 0 else float(dist) / float(length))
score = 1 - min(values)
score = 0 if score < threshold else score
return score
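# Worked example of ANLS above: answer 'gross income' vs ground truth 'gross incomes' has
# Levenshtein distance 1 and max length 13, giving a score of 1 - 1/13 (about 0.923), which
# is kept because it is at least the 0.5 threshold; any score below the threshold is zeroed.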
def processPunctuation(self, inText):
outText = inText
for p in self.punct:
if (p + " " in inText or " " + p in inText) or (
re.search(self.commaStrip, inText) is not None
):
outText = outText.replace(p, "")
else:
outText = outText.replace(p, " ")
outText = self.periodStrip.sub("", outText, re.UNICODE)
return outText
def processDigitArticle(self, inText):
outText = []
tempText = inText.lower().split()
for word in tempText:
word = self.manualMap.setdefault(word, word)
if word not in self.articles:
outText.append(word)
else:
pass
for wordId, word in enumerate(outText):
if word in self.contractions:
outText[wordId] = self.contractions[word]
outText = " ".join(outText)
return outText
def evaluate_dataset(dataset_name, answer_file_path, model_name, method=None):
with open(answer_file_path, 'r', encoding='utf-8') as f:
predictions = json.load(f)
eval = VQAEval()
total_accuracy = 0
num = 0
Entry = namedtuple('Entry', ['text', 'bbox'])
for item in predictions:
gt_answers = item['gt_answers']
answer = item['answer']
if method is not None:
pass
if dataset_name in ["textVQA"]:
if num == 0:
print(f"evaluating vqa...")
accuracy = eval.evaluate_vqa_human(answer, gt_answers)
elif dataset_name in ['docVQA']:
if num == 0:
print(f"evaluating anls...")
accuracy = eval.evaluate_anls(answer, gt_answers)
else:
accuracy = eval.evaluate_has(answer, gt_answers)
item['accuracy'] = accuracy
total_accuracy += accuracy
num += 1
average_accuracy = total_accuracy / num
print(f'{dataset_name}:{average_accuracy}')
answer_model_method_path = answer_file_path.replace('.json', f'_{model_name}_{method}.json')
with open(answer_model_method_path, "w", encoding='utf-8') as f:
json.dump(predictions, f, indent=4, ensure_ascii=False)
return average_accuracy
def evaluate_VQA(
model,
dataset,
model_name,
dataset_name,
time,
batch_size=1,
generate_method="interleave",
answer_path='./answers',
):
print(f"answer path:{answer_path}")
sampler = None
if torch.distributed.is_initialized():
sampler = InferenceSampler(len(dataset))
dataloader = torch.utils.data.DataLoader(
dataset=dataset,
batch_size=batch_size,
sampler=sampler,
collate_fn=collate_fn_vqa
)
now_rank = torch.distributed.get_rank()
answer_dir = os.path.join(answer_path, model_name, time)
os.makedirs(answer_dir, exist_ok=True)
image_list = []
for item in dataset:
image_list.append(item["image_path"])
predictions = []
for batch in tqdm(dataloader, desc="Running inference"):
image_paths, questions, gt_answers, ocr_tokens_list, question_ids, question_type = batch
with torch.no_grad():
if model_name != "minicpm":
if model_name != "codellama":
outputs = model.generate(images=image_paths, questions=questions, datasetname=dataset_name)
else:
outputs = model.generate()
elif model_name == "minicpm":
if generate_method == "old":
outputs = model.generate(images=image_paths, questions=questions, datasetname=dataset_name)
elif generate_method == "interleave":
outputs = model.generate_with_interleaved(images=image_paths, questions=questions, datasetname=dataset_name)
else:
raise Exception(f"Wrong generate paradigm {generate_method}!")
for i in range(len(outputs)):
answer_dict = {
'question_id': question_ids[i],
'question': questions[i],
'answer': outputs[i],
'gt_answers': gt_answers[i],
'image_path': image_paths[i],
'model_name': model_name,
'question_type': question_type[i]
}
predictions.append(answer_dict)
if torch.distributed.is_initialized():
torch.distributed.barrier()
if torch.distributed.is_initialized():
world_size = torch.distributed.get_world_size()
merged_predictions = [None for _ in range(world_size)]
torch.distributed.all_gather_object(merged_predictions, predictions)
predictions = [_ for _ in itertools.chain.from_iterable(merged_predictions)]
if torch.distributed.is_initialized() and torch.distributed.get_rank() != 0:
return None
answer_file_path = os.path.join(answer_dir, f"{dataset_name}.json")
print(f"answer_file_path:{answer_file_path}")
with open(answer_file_path, "w", encoding='utf-8') as f:
json.dump(predictions, f, indent=4, ensure_ascii=False)
if dataset_name in ["docVQATest"]:
return -1.0
return evaluate_dataset(answer_file_path=answer_file_path, dataset_name=dataset_name, model_name=model_name)
import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer
Image.MAX_IMAGE_PIXELS = 1000000000
max_token = {
'docVQA': 100,
'textVQA': 100,
"docVQATest": 100
}
class MiniCPM_V:
def __init__(self, model_path, ckpt, device=None)->None:
self.model_path = model_path
self.ckpt = ckpt
self.model = AutoModel.from_pretrained(self.model_path, trust_remote_code=True).eval()
if self.ckpt is not None:
self.ckpt = ckpt
self.state_dict = torch.load(self.ckpt, map_location=torch.device('cpu'))
self.model.load_state_dict(self.state_dict)
self.model = self.model.to(dtype=torch.float16)
self.model.to(device)
self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
torch.cuda.empty_cache()
def generate(self, images, questions, datasetname):
image = Image.open(images[0]).convert('RGB')
try:
max_new_tokens = max_token[datasetname]
except KeyError:
max_new_tokens = 1024
if datasetname in ('docVQA', 'docVQATest', 'textVQA'):
prompt = "Answer the question directly with single word." + "\n" + questions[0]
else:
# Fall back to the bare question for datasets without a dedicated prompt.
prompt = questions[0]
msgs = [{'role': 'user', 'content': prompt}]
default_kwargs = dict(
max_new_tokens=max_new_tokens,
sampling=False,
num_beams=3
)
res = self.model.chat(
image=image,
msgs=msgs,
context=None,
tokenizer=self.tokenizer,
**default_kwargs
)
return [res]
def generate_with_interleaved(self, images, questions, datasetname):
try:
max_new_tokens = max_token[datasetname]
except KeyError:
max_new_tokens = 1024
prompt = "Answer the question directly with single word."
default_kwargs = dict(
max_new_tokens=max_new_tokens,
sampling=False,
num_beams=3
)
content = []
message = [
{'type': 'text', 'value': prompt},
{'type': 'image', 'value': images[0]},
{'type': 'text', 'value': questions[0]}
]
for x in message:
if x['type'] == 'text':
content.append(x['value'])
elif x['type'] == 'image':
image = Image.open(x['value']).convert('RGB')
content.append(image)
msgs = [{'role': 'user', 'content': content}]
res = self.model.chat(
msgs=msgs,
context=None,
tokenizer=self.tokenizer,
**default_kwargs
)
if isinstance(res, tuple) and len(res) > 0:
res = res[0]
print(f"Q: {content}, \nA: {res}")
return [res]
accelerate
aiohttp==3.8.4
aiosignal==1.3.1
async-timeout==4.0.2
attrs==22.2.0
bitsandbytes==0.37.0
cchardet==2.1.7
chardet==5.1.0
contourpy==1.0.7
cycler==0.11.0
filelock==3.9.0
fonttools==4.38.0
frozenlist==1.3.3
huggingface-hub==0.13.4
importlib-resources==5.12.0
kiwisolver==1.4.4
matplotlib==3.7.0
multidict==6.0.4
openai==0.27.0
packaging==23.0
psutil==5.9.4
pycocotools==2.0.6
pyparsing==3.0.9
python-dateutil==2.8.2
pyyaml==6.0
regex==2022.10.31
tokenizers==0.13.2
tqdm==4.64.1
transformers
timm==0.6.13
spacy==3.5.1
webdataset==0.2.48
scikit-learn==1.2.2
scipy==1.10.1
yarl==1.8.2
zipp==3.14.0
omegaconf==2.3.0
opencv-python==4.7.0.72
iopath==0.1.10
decord==0.6.0
tenacity==8.2.2
peft
pycocoevalcap
sentence-transformers
umap-learn
notebook
gradio==3.24.1
gradio-client==0.0.8
wandb