Commit 51fea27f authored by lvskiller

model

parent c9fa4205
from argparse import ArgumentParser
from pathlib import Path
import copy
import gradio as gr
import os
import re
import secrets
import tempfile
from PIL import Image
from monkey_model.modeling_monkey import MonkeyLMHeadModel
from monkey_model.tokenization_qwen import QWenTokenizer
from monkey_model.configuration_monkey import MonkeyConfig
import shutil
import json
DEFAULT_CKPT_PATH = '/home/zhangli/demo/'
BOX_TAG_PATTERN = r"<box>([\s\S]*?)</box>"
PUNCTUATION = "!?。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏."
title_markdown = ("""
# Welcome to Monkey
Hello! I'm Monkey, a Large Language and Vision Assistant. Before talking to me, please read the **Operation Guide** and **Terms of Use**.
你好!我是Monkey,一个大型语言和视觉助理。在与我交谈之前,请阅读**操作指南**和**使用条款**。
## Operation Guide 操作指南
Click the **Upload** button to upload an image. Then, you can get Monkey's answer in two ways: 点击**Upload**上传图像。你可以通过两种方式得到Monkey的回答:
- Click the **Generate** button and Monkey will generate a description of the image. 点击**Generate**,Monkey将生成图像的描述。
- Enter your question in the dialog box, click the **Submit** button, and Monkey will answer the question based on the image. 在对话框中输入问题,点击**Submit**,Monkey会根据图片回答问题。
- Click **Clear History** to clear the current image and Q&A content. 点击**Clear History**,清除当前图片和问答内容。
> Note: Monkey does not have a multi-round dialogue function. Perhaps we will further develop its capabilities in the future. 注意:Monkey没有多轮对话功能,或许我们在未来会进一步开发它的能力。
> Monkey supports Chinese, but questions asked in English produce noticeably better results than questions asked in Chinese. Monkey支持中文,但使用英文提问会比使用中文效果明显好。""")
policy_markdown = ("""
## Terms of Use
By using this service, users are required to agree to the following terms:
- Monkey is for research use only and unauthorized commercial use is prohibited. For any query, please contact the author.
- Monkey's generation capabilities are limited, so we recommend that users do not rely entirely on its answers.
- Monkey's security measures are limited, so we cannot guarantee that the output is completely appropriate. We strongly recommend that users do not intentionally guide Monkey to generate harmful content, including hate speech, discrimination, violence, pornography, deception, etc.
""")
def _get_args():
parser = ArgumentParser()
parser.add_argument("-c", "--checkpoint-path", type=str, default=DEFAULT_CKPT_PATH,
help="Checkpoint name or path, default to %(default)r")
parser.add_argument("--cpu-only", action="store_true", help="Run demo with CPU only")
parser.add_argument("--share", action="store_true", default=False,
help="Create a publicly shareable link for the interface.")
parser.add_argument("--inbrowser", action="store_true", default=False,
help="Automatically launch the interface in a new tab on the default browser.")
parser.add_argument("--server-port", type=int, default=8000,
help="Demo server port.")
parser.add_argument("--server-name", type=str, default="127.0.0.1",
help="Demo server name.")
args = parser.parse_args()
return args
def _load_model_tokenizer(args):
tokenizer = QWenTokenizer.from_pretrained(
args.checkpoint_path, trust_remote_code=True)
if args.cpu_only:
device_map = "cpu"
else:
device_map = "cuda"
model = MonkeyLMHeadModel.from_pretrained(
args.checkpoint_path,
device_map=device_map,
trust_remote_code=True,
).eval()
# model.generation_config = GenerationConfig.from_pretrained(
# args.checkpoint_path, trust_remote_code=True, resume_download=True,
# )
tokenizer.padding_side = 'left'
tokenizer.pad_token_id = tokenizer.eod_id
return model, tokenizer
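# Standalone usage sketch (illustrative only; the image path and question below are hypothetical,
# and this mirrors the prompt format and generate() arguments used in predict() further down):
#   model, tokenizer = _load_model_tokenizer(_get_args())
#   query = '<img>/path/to/image.jpg</img> What is shown in this image? Answer: '
#   inputs = tokenizer(query, return_tensors='pt', padding='longest')
#   pred = model.generate(input_ids=inputs.input_ids.cuda(), attention_mask=inputs.attention_mask.cuda(),
#                         do_sample=False, max_new_tokens=512,
#                         pad_token_id=tokenizer.eod_id, eos_token_id=tokenizer.eod_id)
#   print(tokenizer.decode(pred[0][inputs.input_ids.size(1):], skip_special_tokens=True).strip())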
def _parse_text(text):
lines = text.split("\n")
lines = [line for line in lines if line != ""]
count = 0
for i, line in enumerate(lines):
if "```" in line:
count += 1
items = line.split("`")
if count % 2 == 1:
lines[i] = f'<pre><code class="language-{items[-1]}">'
else:
lines[i] = f"<br></code></pre>"
else:
if i > 0:
if count % 2 == 1:
line = line.replace("`", r"\`")
line = line.replace("<", "&lt;")
line = line.replace(">", "&gt;")
line = line.replace(" ", "&nbsp;")
line = line.replace("*", "&ast;")
line = line.replace("_", "&lowbar;")
line = line.replace("-", "&#45;")
line = line.replace(".", "&#46;")
line = line.replace("!", "&#33;")
line = line.replace("(", "&#40;")
line = line.replace(")", "&#41;")
line = line.replace("$", "&#36;")
lines[i] = "<br>" + line
text = "".join(lines)
return text
def _launch_demo(args, model, tokenizer):
def predict(_chatbot, task_history):
chat_query = _chatbot[-1][0]
query = task_history[-1][0]
question = _parse_text(query)
print("User: " + _parse_text(query))
full_response = ""
        try:
            img_path = _chatbot[0][0][0]
            Image.open(img_path)
except:
response = "Please upload a picture."
_chatbot[-1] = (_parse_text(chat_query), response)
full_response = _parse_text(response)
task_history[-1] = (query, full_response)
print("Monkey: " + _parse_text(full_response))
return _chatbot
query = f'<img>{img_path}</img> {question} Answer: '
print(query)
input_ids = tokenizer(query, return_tensors='pt', padding='longest')
attention_mask = input_ids.attention_mask
input_ids = input_ids.input_ids
pred = model.generate(
input_ids=input_ids.cuda(),
attention_mask=attention_mask.cuda(),
do_sample=False,
num_beams=1,
max_new_tokens=512,
min_new_tokens=1,
length_penalty=3,
num_return_sequences=1,
output_hidden_states=True,
use_cache=True,
pad_token_id=tokenizer.eod_id,
eos_token_id=tokenizer.eod_id,
)
response = tokenizer.decode(pred[0][input_ids.size(1):].cpu(), skip_special_tokens=True).strip()
_chatbot[-1] = (_parse_text(chat_query), response)
full_response = _parse_text(response)
task_history[-1] = (query, full_response)
print("Monkey: " + _parse_text(full_response))
return _chatbot
def caption(_chatbot, task_history):
query = "Generate the detailed caption in English:"
chat_query = "Generate the detailed caption in English:"
question = _parse_text(query)
print("User: " + _parse_text(query))
full_response = ""
try:
img_path = _chatbot[0][0][0]
Image.open(img_path)
except:
response = "Please upload a picture."
_chatbot.append((None, response))
full_response = _parse_text(response)
task_history.append((None, full_response))
print("Monkey: " + _parse_text(full_response))
return _chatbot
img_path = _chatbot[0][0][0]
query = f'<img>{img_path}</img> {chat_query} '
print(query)
input_ids = tokenizer(query, return_tensors='pt', padding='longest')
attention_mask = input_ids.attention_mask
input_ids = input_ids.input_ids
pred = model.generate(
input_ids=input_ids.cuda(),
attention_mask=attention_mask.cuda(),
do_sample=True,
temperature=0.7,
max_new_tokens=250,
min_new_tokens=1,
length_penalty=3,
num_return_sequences=1,
output_hidden_states=True,
use_cache=True,
pad_token_id=tokenizer.eod_id,
eos_token_id=tokenizer.eod_id,
)
response = tokenizer.decode(pred[0][input_ids.size(1):].cpu(), skip_special_tokens=True).strip()
_chatbot.append((None, response))
full_response = _parse_text(response)
task_history.append((None, full_response))
print("Monkey: " + _parse_text(full_response))
return _chatbot
def add_text(history, task_history, text):
task_text = text
if len(text) >= 2 and text[-1] in PUNCTUATION and text[-2] not in PUNCTUATION:
task_text = text[:-1]
history = history + [(_parse_text(text), None)]
task_history = task_history + [(task_text, None)]
print(history, task_history, text)
return history, task_history, ""
def add_file(history, task_history, file):
history = [((file.name,), None)]
task_history = [((file.name,), None)]
print(history, task_history, file)
return history, task_history
def reset_user_input():
return gr.update(value="")
def reset_state(task_history):
task_history.clear()
return []
with gr.Blocks() as demo:
gr.Markdown(title_markdown)
chatbot = gr.Chatbot(label='Monkey', elem_classes="control-height", height=600,avatar_images=("https://ooo.0x0.ooo/2023/11/09/OehsLx.png","https://ooo.0x0.ooo/2023/11/09/OehGBC.png"),layout="bubble",bubble_full_width=False,show_copy_button=True)
query = gr.Textbox(lines=1, label='Input')
task_history = gr.State([])
with gr.Row():
empty_bin = gr.Button("Clear History (清空)")
submit_btn = gr.Button("Submit (提问)")
generate_btn_en = gr.Button("Generate")
addfile_btn = gr.UploadButton("Upload (上传图片)", file_types=["image"])
submit_btn.click(add_text, [chatbot, task_history, query], [chatbot, task_history]).then(
predict, [chatbot, task_history], [chatbot], show_progress=True
)
generate_btn_en.click(caption, [chatbot, task_history], [chatbot], show_progress=True)
submit_btn.click(reset_user_input, [], [query])
empty_bin.click(reset_state, [task_history], [chatbot], show_progress=True)
addfile_btn.upload(add_file, [chatbot, task_history, addfile_btn], [chatbot, task_history], show_progress=True,scroll_to_output=True)
gr.Markdown(policy_markdown)
demo.queue().launch(
server_name="0.0.0.0",
server_port=7681
)
def main():
args = _get_args()
model, tokenizer = _load_model_tokenizer(args)
_launch_demo(args, model, tokenizer)
if __name__ == '__main__':
main()
EVAL_PTH=$1
SAVE_NAME=$2
python -m torch.distributed.launch --use-env \
    --nproc_per_node ${NPROC_PER_NODE:-1} \
    --nnodes ${WORLD_SIZE:-1} \
    --node_rank ${RANK:-0} \
    --master_addr ${MASTER_ADDR:-127.0.0.1} \
    --master_port ${MASTER_PORT:-12345} \
    eval/evaluate_vqa.py \
    --checkpoint $EVAL_PTH \
    --batch-size 4 \
    --num-workers 2 \
    --save_name $SAVE_NAME
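# Example invocation (illustrative; the checkpoint path, run name, and the assumption that this
# launcher is saved as eval/eval.sh are all hypothetical):
#   NPROC_PER_NODE=8 bash eval/eval.sh /path/to/monkey_checkpoint monkey_vqa_run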
import argparse
import itertools
import json
import os
import random
import time
from functools import partial
from typing import Optional
import sys
import torch
from tqdm import tqdm
from vqa import VQA
from vqa_eval import VQAEval
sys.path.append("pathto/Monkey/")
from monkey_model.modeling_monkey import MonkeyLMHeadModel
from monkey_model.tokenization_qwen import QWenTokenizer
import numpy as np
from pathlib import Path
time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime())
ds_collections = {
'estvqa_test': {
'train': 'data/ESTVQA/estvqa.jsonl',
'test': 'data/estvqa/estvqa.jsonl',
'metric': 'anls',
'max_new_tokens': 100,
},
'docvqa_test': {
'train': 'data/docvqa/train.jsonl',
'test': 'data/docvqa/test_ans.jsonl',
'metric': 'anls',
'max_new_tokens': 100,
},
'chartqa_ureader': {
'train': 'data/chartqa/train_augmented.jsonl',
'test': 'data/chartqa/chartqa_ureader.jsonl',
'metric': 'relaxed_accuracy',
'max_new_tokens': 100,
},
'infovqa_test': {
'train': 'data/infographicVQA/infovqa.jsonl',
'test': 'data/infographicVQA/infovqa_test.jsonl',
'metric': 'anls',
'max_new_tokens': 100,
},
'vizwiz_val': {
'train': 'data/vizwiz/vizwiz_train.jsonl',
'test': 'data/vizwiz/vizwiz_val.jsonl',
'question': 'data/vizwiz/vizwiz_val_questions.json',
'annotation': 'data/vizwiz/vizwiz_val_annotations.json',
'metric': 'vqa_score',
'max_new_tokens': 10,
},
'deepform': {
'train': '',
'test': 'data/ureader/test_DeepForm.jsonl',
'metric': 'accuracy',
'max_new_tokens': 100,
},
'KLC': {
'train': '',
'test': 'data/ureader/test_KleisterCharity.jsonl',
'metric': 'accuracy',
'max_new_tokens': 100,
},
'WTQ': {
'train': '',
'test': 'data/ureader/test_WikiTableQuestions.jsonl',
'metric': 'accuracy',
'max_new_tokens': 100,
},
'gqa_testdev': {
'train': 'data/gqa/train.jsonl',
'test': 'data/gqa/gqa_testdev_new.json',
'metric': 'accuracy',
'max_new_tokens': 10,
},
'okvqa_val': {
'train': 'data/okvqa/okvqa_train.jsonl',
'test': 'data/okvqa/okvqa_val.jsonl',
'question': 'data/okvqa/OpenEnded_mscoco_val2014_questions.json',
'annotation': 'data/okvqa/mscoco_val2014_annotations.json',
'metric': 'vqa_score',
'max_new_tokens': 10,
},
'textvqa_val': {
'train': 'data/textvqa/textvqa_train.jsonl',
'test': 'data/textvqa/textvqa_val.jsonl',
'question': 'data/textvqa/textvqa_val_questions.json',
'annotation': 'data/textvqa/textvqa_val_annotations.json',
'metric': 'vqa_score',
'max_new_tokens': 10,
},
'stvqa_test': {
'train': 'data/STVQA/stvqa.jsonl',
'test': 'data/STVQA/stvqa.jsonl',
'metric': 'anls',
'max_new_tokens': 100,
},
'ai2diagram_test': {
'train': 'data/ai2d/train.jsonl',
'test': 'data/ai2d/test.jsonl',
'metric': 'accuracy',
'max_new_tokens': 10,
},
'vqav2_val': {
'train': 'data/vqav2/vqav2_train.jsonl',
'test': 'data/vqav2/vqav2_val.jsonl',
'question': 'data/vqav2/v2_OpenEnded_mscoco_val2014_questions.json',
'annotation': 'data/vqav2/v2_mscoco_val2014_annotations.json',
'metric': 'vqa_score',
'max_new_tokens': 10,
},
}
def levenshtein_distance(s1, s2):
if len(s1) > len(s2):
s1, s2 = s2, s1
distances = range(len(s1) + 1)
for i2, c2 in enumerate(s2):
distances_ = [i2+1]
for i1, c1 in enumerate(s1):
if c1 == c2:
distances_.append(distances[i1])
else:
distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1])))
distances = distances_
return distances[-1]
def normANLS(s1,s2):
dist = levenshtein_distance(s1.lower().strip(),s2.lower().strip())
length = max(len(s1),len(s2))
value = 0.0 if length == 0 else float(dist) / float(length)
return value
def evaluateANLS(ans_list):
anls_threshold = 0.5
anls_list = []
for predict_pair in ans_list:
answer = predict_pair["answer"].strip()
gt_list = predict_pair["annotation"]
value_list = []
for gt_single in gt_list:
value_list.append(normANLS(gt_single,answer))
question_result = 1 - min(value_list)
if (question_result < anls_threshold) :
question_result = 0
anls_list.append(question_result)
return np.mean(anls_list)
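# Worked example for the ANLS scoring above (illustrative values):
#   prediction "helo" vs. ground truth "hello": Levenshtein distance 1, max length 5,
#   normalized distance 0.2, so the question scores 1 - 0.2 = 0.8.
#   Scores below the 0.5 threshold are set to 0 before averaging over all questions.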
# https://github.com/google-research/pix2struct/blob/main/pix2struct/metrics.py#L81
def relaxed_correctness(target: str,
prediction: str,
max_relative_change: float = 0.05) -> bool:
"""Calculates relaxed correctness.
The correctness tolerates certain error ratio defined by max_relative_change.
See https://arxiv.org/pdf/2203.10244.pdf, end of section 5.1:
“Following Methani et al. (2020), we use a relaxed accuracy measure for the
numeric answers to allow a minor inaccuracy that may result from the automatic
data extraction process. We consider an answer to be correct if it is within
5% of the gold answer. For non-numeric answers, we still need an exact match
to consider an answer to be correct.”
Args:
target: Target string.
prediction: Predicted string.
max_relative_change: Maximum relative change.
Returns:
Whether the prediction was correct given the specified tolerance.
"""
def _to_float(text: str) -> Optional[float]:
try:
if text.endswith('%'):
# Convert percentages to floats.
return float(text.rstrip('%')) / 100.0
else:
return float(text)
except ValueError:
return None
prediction_float = _to_float(prediction)
target_float = _to_float(target)
if prediction_float is not None and target_float:
relative_change = abs(prediction_float -
target_float) / abs(target_float)
return relative_change <= max_relative_change
else:
return prediction.lower() == target.lower()
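# Quick sanity checks for relaxed_correctness (illustrative):
#   relaxed_correctness("4.5", "4.6")  -> True   (|4.6 - 4.5| / 4.5 ~= 0.022 <= 0.05)
#   relaxed_correctness("12%", "0.13") -> False  (|0.13 - 0.12| / 0.12 ~= 0.083 > 0.05)
#   relaxed_correctness("cat", "Cat")  -> True   (non-numeric answers need a case-insensitive exact match)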
def evaluate_relaxed_accuracy(entries):
scores = []
for elem in entries:
if isinstance(elem['annotation'], str):
elem['annotation'] = [elem['annotation']]
score = max([
relaxed_correctness(elem['answer'].strip(), ann)
for ann in elem['annotation']
])
scores.append(score)
return sum(scores) / len(scores)
def evaluate_exact_match_accuracy(entries):
scores = []
for elem in entries:
if isinstance(elem['annotation'], str):
elem['annotation'] = [elem['annotation']]
score = max([
(1.0 if
(ann.strip().lower() in elem['answer'].strip().lower().replace(".","") ) else 0.0)
for ann in elem['annotation']
])
scores.append(score)
return sum(scores) / len(scores)

def collate_fn(batches, tokenizer):
image_paths = [_['image_path'] for _ in batches]
questions = [_['question'] for _ in batches]
question_ids = [_['question_id'] for _ in batches]
annotations = [_['annotation'] for _ in batches]
input_ids = tokenizer(questions, return_tensors='pt', padding='longest')
return image_paths,question_ids, input_ids.input_ids, input_ids.attention_mask, annotations
class VQADataset(torch.utils.data.Dataset):
def __init__(self, train, test, prompt, few_shot):
self.test = open(test).readlines()
self.prompt = prompt
self.few_shot = few_shot
if few_shot > 0:
self.train = open(train).readlines()
def __len__(self):
return len(self.test)
def __getitem__(self, idx):
data = json.loads(self.test[idx].strip())
image, question, question_id, annotation = data['image'], data[
'question'], data['question_id'], data.get('answer', None)
few_shot_prompt = ''
if self.few_shot > 0:
few_shot_samples = random.sample(self.train, self.few_shot)
for sample in few_shot_samples:
sample = json.loads(sample.strip())
few_shot_prompt += self.prompt.format(
sample['image'],
sample['question']) + f" {sample['answer']}"
return {
'image_path':image,
'question': few_shot_prompt + self.prompt.format(image, question),
'question_id': question_id,
'annotation': annotation
}
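# For a test line such as {"image": "data/docvqa/images/sample.png", "question": "What is the date?",
# "question_id": 1, "answer": ["1988-09-15"]} (hypothetical sample), the zero-shot prompt built above is:
#   <img>data/docvqa/images/sample.png</img>What is the date? Answer:
# following the prompt template passed in from the __main__ block at the bottom of this file.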
class InferenceSampler(torch.utils.data.sampler.Sampler):
def __init__(self, size):
self._size = int(size)
assert size > 0
self._rank = torch.distributed.get_rank()
self._world_size = torch.distributed.get_world_size()
self._local_indices = self._get_local_indices(size, self._world_size,
self._rank)
@staticmethod
def _get_local_indices(total_size, world_size, rank):
shard_size = total_size // world_size
left = total_size % world_size
shard_sizes = [shard_size + int(r < left) for r in range(world_size)]
begin = sum(shard_sizes[:rank])
end = min(sum(shard_sizes[:rank + 1]), total_size)
return range(begin, end)
def __iter__(self):
yield from self._local_indices
def __len__(self):
return len(self._local_indices)
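# Sharding example for InferenceSampler (illustrative): with 10 samples and world_size 3,
# shard_sizes = [4, 3, 3], so rank 0 iterates indices 0-3, rank 1 indices 4-6, and rank 2 indices 7-9.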
def evaluate(model,tokenizer,prompt,args,dataset_name):
dataset_info = ds_collections[dataset_name]
dataset = VQADataset(
train=dataset_info['train'],
test=dataset_info['test'],
prompt=prompt,
few_shot=args.few_shot,
)
len_dataset = len(dataset)
if torch.distributed.get_rank() == 0:
print(f"there have {len(dataset)} in {dataset_name}")
dataloader = torch.utils.data.DataLoader(
dataset=dataset,
sampler=InferenceSampler(len_dataset),
batch_size=args.batch_size,
num_workers=args.num_workers,
pin_memory=True,
drop_last=False,
collate_fn=partial(collate_fn, tokenizer=tokenizer),
)
outputs = []
for image_paths,question_ids, input_ids, attention_mask,annotations in tqdm(dataloader):
pred = model.generate(
input_ids=input_ids.cuda(),
attention_mask=attention_mask.cuda(),
do_sample=False,
num_beams=1,
max_new_tokens=dataset_info['max_new_tokens'],
min_new_tokens=1,
length_penalty=1,
num_return_sequences=1,
output_hidden_states=True,
use_cache=True,
pad_token_id=tokenizer.eod_id,
eos_token_id=tokenizer.eod_id,
)
answers = [
tokenizer.decode(_[input_ids.size(1):].cpu(),
skip_special_tokens=True).strip() for _ in pred
]
for image_path,question_id, answer, annotation in zip(image_paths,question_ids, answers,
annotations):
if dataset_name in ['vqav2_val', 'okvqa_val', 'textvqa_val', 'vizwiz_val']:
outputs.append({
'image_path':image_path,
'question_id': question_id,
'answer': answer,
})
elif dataset_name in ['docvqa_test', 'gqa_testdev',"stvqa_test","infovqa_test"]:
outputs.append({
'image_path':image_path,
'questionId': question_id,
'answer': answer,
'annotation': annotation,
})
elif dataset_name in ['ai2diagram_test',"WTQ","deepform","KLC"]:
outputs.append({
'image_path':image_path,
'image': question_id,
'answer': answer,
'annotation': annotation,
})
elif dataset_name in ['estvqa_test']:
outputs.append({
'image_path':image_path,
'questionId': question_id,
'answer': answer,
'annotation': [annotation],
})
elif dataset_name in ["chartqa_ureader"]:
outputs.append({
'image_path':image_path,
'answer': answer,
'annotation': annotation,
})
else:
raise NotImplementedError
torch.distributed.barrier()
world_size = torch.distributed.get_world_size()
merged_outputs = [None for _ in range(world_size)]
torch.distributed.all_gather_object(merged_outputs, json.dumps(outputs))
merged_outputs = [json.loads(_) for _ in merged_outputs]
merged_outputs = [_ for _ in itertools.chain.from_iterable(merged_outputs)]
if torch.distributed.get_rank() == 0:
print(f"Evaluating {dataset_name} ...")
results_file = f'{dataset_name}.json'
root_path = os.path.join("result",args.save_name)
Path(root_path).mkdir(exist_ok=True,parents=True)
results_file = os.path.join(root_path,results_file)
json.dump(merged_outputs, open(results_file, 'w',encoding="utf-8"), ensure_ascii=False,indent=2)
if dataset_info['metric'] == 'vqa_score':
vqa = VQA(dataset_info['annotation'],dataset_info['question'])
results = vqa.loadRes(
resFile=results_file,
quesFile=dataset_info['question'])
vqa_scorer = VQAEval(vqa, results, n=2)
question_id_list = [item["question_id"]for item in merged_outputs]
vqa_scorer.evaluate(question_id_list)
print(vqa_scorer.accuracy)
results_file = results_file.replace("json","txt")
with open(results_file,"w") as fp:
fp.write(dataset_name+"\n")
fp.writelines(str(vqa_scorer.accuracy["overall"])+'\n')
elif dataset_info['metric'] == 'anls':
json.dump(merged_outputs,
open(results_file, 'w'),
ensure_ascii=False)
anls_res = evaluateANLS(merged_outputs)
print(anls_res)
results_file = results_file.replace("json","txt")
with open(results_file,"w") as fp:
fp.write(dataset_name+"\n")
fp.writelines(str(anls_res)+'\n')
elif dataset_info['metric'] == 'relaxed_accuracy':
print({
'relaxed_accuracy': evaluate_relaxed_accuracy(merged_outputs)
})
results_file = results_file.replace("json","txt")
with open(results_file,"w") as fp:
fp.write(dataset_name+"\n")
fp.writelines(str(evaluate_relaxed_accuracy(merged_outputs))+'\n')
elif dataset_info['metric'] == 'accuracy':
if 'gqa' in dataset_name:
for entry in merged_outputs:
response = entry['answer']
response = response.strip().split('.')[0].split(
',')[0].split('!')[0].lower()
if 'is ' in response:
response = response.split('is ')[1]
if 'are ' in response:
response = response.split('are ')[1]
if 'a ' in response:
response = response.split('a ')[1]
if 'an ' in response:
response = response.split('an ')[1]
if 'the ' in response:
response = response.split('the ')[1]
if ' of' in response:
response = response.split(' of')[0]
response = response.strip()
entry['answer'] = response
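                # The stripping above reduces, e.g., "It is a red car." to "red car" before exact matching.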
acc = evaluate_exact_match_accuracy(merged_outputs)
print({'accuracy': acc})
results_file = results_file.replace("json","txt")
with open(results_file,"w") as fp:
fp.write(dataset_name+"\n")
fp.writelines(str(acc)+'\n')
torch.distributed.barrier()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--checkpoint', type=str, default='')
parser.add_argument('--dataset', type=str, default='')
parser.add_argument('--batch-size', type=int, default=1)
parser.add_argument('--num-workers', type=int, default=1)
parser.add_argument('--few-shot', type=int, default=0)
parser.add_argument('--seed', type=int, default=3407)
parser.add_argument("--save_name",type=str,default="test")
args = parser.parse_args()
torch.distributed.init_process_group(
backend='nccl',
world_size=int(os.getenv('WORLD_SIZE', '1')),
rank=int(os.getenv('RANK', '0')),
)
torch.cuda.set_device(int(os.getenv('LOCAL_RANK', 0)))
model = MonkeyLMHeadModel.from_pretrained(
args.checkpoint, device_map='cuda', trust_remote_code=True).eval()
tokenizer = QWenTokenizer.from_pretrained(args.checkpoint,
trust_remote_code=True)
tokenizer.padding_side = 'left'
tokenizer.pad_token_id = tokenizer.eod_id
random.seed(args.seed)
for k,_ in ds_collections.items():
if "vizwiz_val" in k:
prompt = '<img>{}</img> {} When the provided information is insufficient, respond with "Unanswerable". Answer:'
else:
prompt = '<img>{}</img>{} Answer:'
evaluate(model,tokenizer,prompt,args,k)
"""Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
__author__ = 'aagrawal'
__version__ = '0.9'
# Interface for accessing the VQA dataset.
# This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link:
# (https://github.com/pdollar/coco/blob/master/PythonAPI/pycocotools/coco.py).
# The following functions are defined:
# VQA - VQA class that loads VQA annotation file and prepares data structures.
# getQuesIds - Get question ids that satisfy given filter conditions.
# getImgIds - Get image ids that satisfy given filter conditions.
# loadQA - Load questions and answers with the specified question ids.
# showQA - Display the specified questions and answers.
# loadRes - Load result file and create result object.
# Help on each function can be accessed by: "help(COCO.function)"
import copy
import datetime
import json
class VQA:
def __init__(self, annotation_file=None, question_file=None):
"""Constructor of VQA helper class for reading and visualizing
questions and answers.
:param annotation_file (str): location of VQA annotation file
:return:
"""
# load dataset
self.dataset = {}
self.questions = {}
self.qa = {}
self.qqa = {}
self.imgToQA = {}
if not annotation_file == None and not question_file == None:
print('loading VQA annotations and questions into memory...')
time_t = datetime.datetime.utcnow()
dataset = json.load(open(annotation_file, 'r'))
questions = json.load(open(question_file, 'r'))
self.dataset = dataset
self.questions = questions
self.createIndex()
def createIndex(self):
# create index
print('creating index...')
imgToQA = {ann['image_id']: [] for ann in self.dataset['annotations']}
qa = {ann['question_id']: [] for ann in self.dataset['annotations']}
qqa = {ann['question_id']: [] for ann in self.dataset['annotations']}
for ann in self.dataset['annotations']:
imgToQA[ann['image_id']] += [ann]
qa[ann['question_id']] = ann
for ques in self.questions['questions']:
qqa[ques['question_id']] = ques
print('index created!')
# create class members
self.qa = qa
self.qqa = qqa
self.imgToQA = imgToQA
def info(self):
"""Print information about the VQA annotation file.
:return:
"""
        for key, value in self.dataset['info'].items():
print('%s: %s' % (key, value))
def getQuesIds(self, imgIds=[], quesTypes=[], ansTypes=[]):
"""Get question ids that satisfy given filter conditions. default skips
that filter.
:param imgIds (int array) : get question ids for given imgs
quesTypes (str array) : get question ids for given question types
ansTypes (str array) : get question ids for given answer types
:return: ids (int array) : integer array of question ids
"""
imgIds = imgIds if type(imgIds) == list else [imgIds]
quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]
if len(imgIds) == len(quesTypes) == len(ansTypes) == 0:
anns = self.dataset['annotations']
else:
if not len(imgIds) == 0:
anns = sum(
[
self.imgToQA[imgId]
for imgId in imgIds if imgId in self.imgToQA
],
[],
)
else:
anns = self.dataset['annotations']
anns = (anns if len(quesTypes) == 0 else
[ann for ann in anns if ann['question_type'] in quesTypes])
anns = (anns if len(ansTypes) == 0 else
[ann for ann in anns if ann['answer_type'] in ansTypes])
ids = [ann['question_id'] for ann in anns]
return ids
def getImgIds(self, quesIds=[], quesTypes=[], ansTypes=[]):
"""Get image ids that satisfy given filter conditions. default skips
that filter.
:param quesIds (int array) : get image ids for given question ids
quesTypes (str array) : get image ids for given question types
ansTypes (str array) : get image ids for given answer types
:return: ids (int array) : integer array of image ids
"""
quesIds = quesIds if type(quesIds) == list else [quesIds]
quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]
if len(quesIds) == len(quesTypes) == len(ansTypes) == 0:
anns = self.dataset['annotations']
else:
if not len(quesIds) == 0:
anns = sum([
self.qa[quesId] for quesId in quesIds if quesId in self.qa
], [])
else:
anns = self.dataset['annotations']
anns = (anns if len(quesTypes) == 0 else
[ann for ann in anns if ann['question_type'] in quesTypes])
anns = (anns if len(ansTypes) == 0 else
[ann for ann in anns if ann['answer_type'] in ansTypes])
ids = [ann['image_id'] for ann in anns]
return ids
def loadQA(self, ids=[]):
"""Load questions and answers with the specified question ids.
:param ids (int array) : integer ids specifying question ids
:return: qa (object array) : loaded qa objects
"""
if type(ids) == list:
return [self.qa[id] for id in ids]
elif type(ids) == int:
return [self.qa[ids]]
def showQA(self, anns):
"""Display the specified annotations.
:param anns (array of object): annotations to display
:return: None
"""
if len(anns) == 0:
return 0
for ann in anns:
quesId = ann['question_id']
print('Question: %s' % (self.qqa[quesId]['question']))
for ans in ann['answers']:
print('Answer %d: %s' % (ans['answer_id'], ans['answer']))
def loadRes(self, resFile, quesFile):
"""Load result file and return a result object.
:param resFile (str) : file name of result file
:return: res (obj) : result api object
"""
res = VQA()
res.questions = json.load(open(quesFile))
res.dataset['info'] = copy.deepcopy(self.questions['info'])
res.dataset['task_type'] = copy.deepcopy(self.questions['task_type'])
res.dataset['data_type'] = copy.deepcopy(self.questions['data_type'])
res.dataset['data_subtype'] = copy.deepcopy(
self.questions['data_subtype'])
res.dataset['license'] = copy.deepcopy(self.questions['license'])
print('Loading and preparing results... ')
time_t = datetime.datetime.utcnow()
anns = json.load(open(resFile))
assert type(anns) == list, 'results is not an array of objects'
annsQuesIds = [ann['question_id'] for ann in anns]
assert set(annsQuesIds) == set(
self.getQuesIds()
), 'Results do not correspond to current VQA set. Either the results do not have predictions for all question ids in annotation file or there is atleast one question id that does not belong to the question ids in the annotation file.'
for ann in anns:
quesId = ann['question_id']
if res.dataset['task_type'] == 'Multiple Choice':
assert (
ann['answer'] in self.qqa[quesId]['multiple_choices']
), 'predicted answer is not one of the multiple choices'
qaAnn = self.qa[quesId]
ann['image_id'] = qaAnn['image_id']
ann['question_type'] = qaAnn['question_type']
ann['answer_type'] = qaAnn['answer_type']
print('DONE (t=%0.2fs)' %
((datetime.datetime.utcnow() - time_t).total_seconds()))
res.dataset['annotations'] = anns
res.createIndex()
return res
"""Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
# coding=utf-8
__author__ = 'aagrawal'
import re
# This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link:
# (https://github.com/tylin/coco-caption/blob/master/pycocoevalcap/eval.py).
import sys
class VQAEval:
def __init__(self, vqa=None, vqaRes=None, n=2):
self.n = n
self.accuracy = {}
self.evalQA = {}
self.evalQuesType = {}
self.evalAnsType = {}
self.vqa = vqa
self.vqaRes = vqaRes
if vqa is not None:
self.params = {'question_id': vqa.getQuesIds()}
self.contractions = {
'aint': "ain't",
'arent': "aren't",
'cant': "can't",
'couldve': "could've",
'couldnt': "couldn't",
"couldn'tve": "couldn't've",
"couldnt've": "couldn't've",
'didnt': "didn't",
'doesnt': "doesn't",
'dont': "don't",
'hadnt': "hadn't",
"hadnt've": "hadn't've",
"hadn'tve": "hadn't've",
'hasnt': "hasn't",
'havent': "haven't",
'hed': "he'd",
"hed've": "he'd've",
"he'dve": "he'd've",
'hes': "he's",
'howd': "how'd",
'howll': "how'll",
'hows': "how's",
"Id've": "I'd've",
"I'dve": "I'd've",
'Im': "I'm",
'Ive': "I've",
'isnt': "isn't",
'itd': "it'd",
"itd've": "it'd've",
"it'dve": "it'd've",
'itll': "it'll",
"let's": "let's",
'maam': "ma'am",
'mightnt': "mightn't",
"mightnt've": "mightn't've",
"mightn'tve": "mightn't've",
'mightve': "might've",
'mustnt': "mustn't",
'mustve': "must've",
'neednt': "needn't",
'notve': "not've",
'oclock': "o'clock",
'oughtnt': "oughtn't",
"ow's'at": "'ow's'at",
"'ows'at": "'ow's'at",
"'ow'sat": "'ow's'at",
'shant': "shan't",
"shed've": "she'd've",
"she'dve": "she'd've",
"she's": "she's",
'shouldve': "should've",
'shouldnt': "shouldn't",
"shouldnt've": "shouldn't've",
"shouldn'tve": "shouldn't've",
"somebody'd": 'somebodyd',
"somebodyd've": "somebody'd've",
"somebody'dve": "somebody'd've",
'somebodyll': "somebody'll",
'somebodys': "somebody's",
'someoned': "someone'd",
"someoned've": "someone'd've",
"someone'dve": "someone'd've",
'someonell': "someone'll",
'someones': "someone's",
'somethingd': "something'd",
"somethingd've": "something'd've",
"something'dve": "something'd've",
'somethingll': "something'll",
'thats': "that's",
'thered': "there'd",
"thered've": "there'd've",
"there'dve": "there'd've",
'therere': "there're",
'theres': "there's",
'theyd': "they'd",
"theyd've": "they'd've",
"they'dve": "they'd've",
'theyll': "they'll",
'theyre': "they're",
'theyve': "they've",
'twas': "'twas",
'wasnt': "wasn't",
"wed've": "we'd've",
"we'dve": "we'd've",
'weve': "we've",
'werent': "weren't",
'whatll': "what'll",
'whatre': "what're",
'whats': "what's",
'whatve': "what've",
'whens': "when's",
'whered': "where'd",
'wheres': "where's",
'whereve': "where've",
'whod': "who'd",
"whod've": "who'd've",
"who'dve": "who'd've",
'wholl': "who'll",
'whos': "who's",
'whove': "who've",
'whyll': "why'll",
'whyre': "why're",
'whys': "why's",
'wont': "won't",
'wouldve': "would've",
'wouldnt': "wouldn't",
"wouldnt've": "wouldn't've",
"wouldn'tve": "wouldn't've",
'yall': "y'all",
"yall'll": "y'all'll",
"y'allll": "y'all'll",
"yall'd've": "y'all'd've",
"y'alld've": "y'all'd've",
"y'all'dve": "y'all'd've",
'youd': "you'd",
"youd've": "you'd've",
"you'dve": "you'd've",
'youll': "you'll",
'youre': "you're",
'youve': "you've",
}
self.manualMap = {
'none': '0',
'zero': '0',
'one': '1',
'two': '2',
'three': '3',
'four': '4',
'five': '5',
'six': '6',
'seven': '7',
'eight': '8',
'nine': '9',
'ten': '10',
}
self.articles = ['a', 'an', 'the']
self.periodStrip = re.compile('(?!<=\d)(\.)(?!\d)')
self.commaStrip = re.compile('(\d)(,)(\d)')
self.punct = [
';',
r'/',
'[',
']',
'"',
'{',
'}',
'(',
')',
'=',
'+',
'\\',
'_',
'-',
'>',
'<',
'@',
'`',
',',
'?',
'!',
]
def evaluate(self, quesIds=None):
if quesIds == None:
quesIds = [quesId for quesId in self.params['question_id']]
gts = {}
res = {}
for quesId in quesIds:
gts[quesId] = self.vqa.qa[quesId]
res[quesId] = self.vqaRes.qa[quesId]
# =================================================
# Compute accuracy
# =================================================
accQA = []
accQuesType = {}
accAnsType = {}
print('computing accuracy')
step = 0
for quesId in quesIds:
resAns = res[quesId]['answer']
resAns = resAns.replace('\n', ' ')
resAns = resAns.replace('\t', ' ')
resAns = resAns.strip()
resAns = self.processPunctuation(resAns)
resAns = self.processDigitArticle(resAns)
gtAcc = []
gtAnswers = [ans['answer'] for ans in gts[quesId]['answers']]
if len(set(gtAnswers)) > 1:
for ansDic in gts[quesId]['answers']:
ansDic['answer'] = self.processPunctuation(
ansDic['answer'])
for gtAnsDatum in gts[quesId]['answers']:
otherGTAns = [
item for item in gts[quesId]['answers']
if item != gtAnsDatum
]
matchingAns = [
item for item in otherGTAns if item['answer'] == resAns
]
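                # Standard VQA accuracy: each ground-truth answer is left out in turn and the
                # prediction scores min(#matching remaining answers / 3, 1); avgGTAcc below averages
                # these leave-one-out scores over the (typically 10) annotator answers.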
acc = min(1, float(len(matchingAns)) / 3)
gtAcc.append(acc)
quesType = gts[quesId]['question_type']
ansType = gts[quesId]['answer_type']
avgGTAcc = float(sum(gtAcc)) / len(gtAcc)
accQA.append(avgGTAcc)
if quesType not in accQuesType:
accQuesType[quesType] = []
accQuesType[quesType].append(avgGTAcc)
if ansType not in accAnsType:
accAnsType[ansType] = []
accAnsType[ansType].append(avgGTAcc)
self.setEvalQA(quesId, avgGTAcc)
self.setEvalQuesType(quesId, quesType, avgGTAcc)
self.setEvalAnsType(quesId, ansType, avgGTAcc)
if step % 100 == 0:
self.updateProgress(step / float(len(quesIds)))
step = step + 1
self.setAccuracy(accQA, accQuesType, accAnsType)
print('Done computing accuracy')
def processPunctuation(self, inText):
outText = inText
for p in self.punct:
if (p + ' ' in inText or ' ' + p
in inText) or (re.search(self.commaStrip, inText) != None):
outText = outText.replace(p, '')
else:
outText = outText.replace(p, ' ')
outText = self.periodStrip.sub('', outText, re.UNICODE)
return outText
def processDigitArticle(self, inText):
outText = []
tempText = inText.lower().split()
for word in tempText:
word = self.manualMap.setdefault(word, word)
if word not in self.articles:
outText.append(word)
else:
pass
for wordId, word in enumerate(outText):
if word in self.contractions:
outText[wordId] = self.contractions[word]
outText = ' '.join(outText)
return outText
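    # Answer normalization example (illustrative): "Two dogs!" -> processPunctuation -> "Two dogs"
    # -> processDigitArticle -> "2 dogs"; contractions such as "dont" are also mapped back to "don't".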
def setAccuracy(self, accQA, accQuesType, accAnsType):
self.accuracy['overall'] = round(100 * float(sum(accQA)) / len(accQA),
self.n)
self.accuracy['perQuestionType'] = {
quesType: round(
100 * float(sum(accQuesType[quesType])) /
len(accQuesType[quesType]),
self.n,
)
for quesType in accQuesType
}
self.accuracy['perAnswerType'] = {
ansType: round(
100 * float(sum(accAnsType[ansType])) /
len(accAnsType[ansType]), self.n)
for ansType in accAnsType
}
def setEvalQA(self, quesId, acc):
self.evalQA[quesId] = round(100 * acc, self.n)
def setEvalQuesType(self, quesId, quesType, acc):
if quesType not in self.evalQuesType:
self.evalQuesType[quesType] = {}
self.evalQuesType[quesType][quesId] = round(100 * acc, self.n)
def setEvalAnsType(self, quesId, ansType, acc):
if ansType not in self.evalAnsType:
self.evalAnsType[ansType] = {}
self.evalAnsType[ansType][quesId] = round(100 * acc, self.n)
def updateProgress(self, progress):
barLength = 20
status = ''
if isinstance(progress, int):
progress = float(progress)
if not isinstance(progress, float):
progress = 0
status = 'error: progress var must be float\r\n'
if progress < 0:
progress = 0
status = 'Halt...\r\n'
if progress >= 1:
progress = 1
status = 'Done...\r\n'
block = int(round(barLength * progress))
        text = '\rFinished Percent: [{0}] {1}% {2}'.format(
'#' * block + '-' * (barLength - block), int(progress * 100),
status)
sys.stdout.write(text)
sys.stdout.flush()
@@ -6,7 +6,7 @@
 from transformers import PretrainedConfig
-class QwenConfig(PretrainedConfig):
+class QWenConfig(PretrainedConfig):
     model_type = "monkey"
     keys_to_ignore_at_inference = ["past_key_values"]