# Copyright (c) Facebook, Inc. and its affiliates.
# Copied from https://github.com/haotian-liu/LLaVA/blob/main/llava/eval/m4c_evaluator.py
import re

try:
    from tqdm import tqdm
except ImportError:
    # tqdm only draws a progress bar; fall back to a no-op wrapper so the
    # module stays importable in minimal environments.
    def tqdm(iterable, *args, **kwargs):
        return iterable


class EvalAIAnswerProcessor:
    """
    Processes an answer similar to Eval AI
        copied from
        https://github.com/facebookresearch/mmf/blob/c46b3b3391275b4181567db80943473a89ab98ab/pythia/tasks/processors.py#L897
    """

    # Maps apostrophe-less spellings to their canonical contraction.
    # NOTE(review): a few keys ("Id've", "I'dve", "Im", "Ive") contain
    # uppercase letters and can never match the lowercased words they are
    # compared against; kept verbatim from the reference implementation
    # for score comparability.
    CONTRACTIONS = {
        'aint': "ain't", 'arent': "aren't", 'cant': "can't",
        'couldve': "could've", 'couldnt': "couldn't",
        "couldn'tve": "couldn't've", "couldnt've": "couldn't've",
        'didnt': "didn't", 'doesnt': "doesn't", 'dont': "don't",
        'hadnt': "hadn't", "hadnt've": "hadn't've", "hadn'tve": "hadn't've",
        'hasnt': "hasn't", 'havent': "haven't", 'hed': "he'd",
        "hed've": "he'd've", "he'dve": "he'd've", 'hes': "he's",
        'howd': "how'd", 'howll': "how'll", 'hows': "how's",
        "Id've": "I'd've", "I'dve": "I'd've", 'Im': "I'm", 'Ive': "I've",
        'isnt': "isn't", 'itd': "it'd", "itd've": "it'd've",
        "it'dve": "it'd've", 'itll': "it'll", "let's": "let's",
        'maam': "ma'am", 'mightnt': "mightn't",
        "mightnt've": "mightn't've", "mightn'tve": "mightn't've",
        'mightve': "might've", 'mustnt': "mustn't", 'mustve': "must've",
        'neednt': "needn't", 'notve': "not've", 'oclock': "o'clock",
        'oughtnt': "oughtn't", "ow's'at": "'ow's'at", "'ows'at": "'ow's'at",
        "'ow'sat": "'ow's'at", 'shant': "shan't", "shed've": "she'd've",
        "she'dve": "she'd've", "she's": "she's", 'shouldve': "should've",
        'shouldnt': "shouldn't", "shouldnt've": "shouldn't've",
        "shouldn'tve": "shouldn't've",
        # Fix: this entry was transcribed reversed ("somebody'd": 'somebodyd'),
        # which *removed* a correct apostrophe; restored to match the
        # reference implementation and every sibling entry.
        'somebodyd': "somebody'd",
        "somebodyd've": "somebody'd've", "somebody'dve": "somebody'd've",
        'somebodyll': "somebody'll", 'somebodys': "somebody's",
        'someoned': "someone'd", "someoned've": "someone'd've",
        "someone'dve": "someone'd've", 'someonell': "someone'll",
        'someones': "someone's", 'somethingd': "something'd",
        "somethingd've": "something'd've", "something'dve": "something'd've",
        'somethingll': "something'll", 'thats': "that's",
        'thered': "there'd", "thered've": "there'd've",
        "there'dve": "there'd've", 'therere': "there're",
        'theres': "there's", 'theyd': "they'd", "theyd've": "they'd've",
        "they'dve": "they'd've", 'theyll': "they'll", 'theyre': "they're",
        'theyve': "they've", 'twas': "'twas", 'wasnt': "wasn't",
        "wed've": "we'd've", "we'dve": "we'd've", 'weve': "we've",
        'werent': "weren't", 'whatll': "what'll", 'whatre': "what're",
        'whats': "what's", 'whatve': "what've", 'whens': "when's",
        'whered': "where'd", 'wheres': "where's", 'whereve': "where've",
        'whod': "who'd", "whod've": "who'd've", "who'dve": "who'd've",
        'wholl': "who'll", 'whos': "who's", 'whove': "who've",
        'whyll': "why'll", 'whyre': "why're", 'whys': "why's",
        'wont': "won't", 'wouldve': "would've", 'wouldnt': "wouldn't",
        "wouldnt've": "wouldn't've", "wouldn'tve": "wouldn't've",
        'yall': "y'all", "yall'll": "y'all'll", "y'allll": "y'all'll",
        "yall'd've": "y'all'd've", "y'alld've": "y'all'd've",
        "y'all'dve": "y'all'd've", 'youd': "you'd", "youd've": "you'd've",
        "you'dve": "you'd've", 'youll': "you'll", 'youre': "you're",
        'youve': "you've",
    }

    # Number words that the VQA metric treats as equivalent to digits.
    NUMBER_MAP = {
        'none': '0', 'zero': '0', 'one': '1', 'two': '2', 'three': '3',
        'four': '4', 'five': '5', 'six': '6', 'seven': '7', 'eight': '8',
        'nine': '9', 'ten': '10',
    }

    ARTICLES = ['a', 'an', 'the']

    # NOTE(review): "(?!<=\d)" is a long-standing typo in the official VQA
    # eval script (the intended lookbehind is "(?<!\d)", i.e. strip periods
    # not inside a decimal number). Kept byte-identical so scores match the
    # reference implementation.
    PERIOD_STRIP = re.compile(r'(?!<=\d)(\.)(?!\d)')
    # Strips commas used as thousands separators (only between digits).
    COMMA_STRIP = re.compile(r'(?<=\d)(\,)+(?=\d)')

    PUNCTUATIONS = [
        ';', r'/', '[', ']', '"', '{', '}', '(', ')', '=', '+', '\\', '_',
        '-', '>', '<', '@', '`', ',', '?', '!',
    ]

    def __init__(self, *args, **kwargs):
        # Stateless processor; accepts and ignores arbitrary args for
        # drop-in compatibility with other mmf processors.
        pass

    def word_tokenize(self, word):
        """Lowercase, drop commas and question marks, split possessive 's."""
        word = word.lower()
        word = word.replace(',', '').replace('?', '').replace("'s", " 's")
        return word.strip()

    def process_punctuation(self, in_text):
        """Remove or space out punctuation per the VQA eval rules.

        A punctuation mark adjacent to a space (or any text containing a
        digit-comma-digit pattern) is deleted outright; otherwise it is
        replaced by a space so the neighboring tokens stay separated.
        """
        out_text = in_text
        for p in self.PUNCTUATIONS:
            if (p + ' ' in in_text or ' ' + p in in_text) or (
                re.search(self.COMMA_STRIP, in_text) is not None
            ):
                out_text = out_text.replace(p, '')
            else:
                out_text = out_text.replace(p, ' ')
        # Fix: the original called sub('', out_text, re.UNICODE), passing the
        # flag constant (value 32) as the positional ``count`` argument and
        # silently capping substitutions at 32. Flags belong in re.compile.
        out_text = self.PERIOD_STRIP.sub('', out_text)
        return out_text

    def process_digit_article(self, in_text):
        """Map number words to digits, drop articles, expand contractions."""
        words = []
        for word in in_text.lower().split():
            # Fix: the original used NUMBER_MAP.setdefault(word, word), which
            # inserted every unseen word into the shared class-level dict,
            # growing it without bound across all instances. ``get`` returns
            # the identical value without mutating the map.
            word = self.NUMBER_MAP.get(word, word)
            if word not in self.ARTICLES:
                words.append(word)
        words = [self.CONTRACTIONS.get(word, word) for word in words]
        return ' '.join(words)

    def __call__(self, item):
        """Fully normalize a raw answer string for comparison."""
        item = self.word_tokenize(item)
        item = item.replace('\n', ' ').replace('\t', ' ').strip()
        item = self.process_punctuation(item)
        item = self.process_digit_article(item)
        return item


class TextVQAAccuracyEvaluator:
    """VQA-style soft accuracy against 10 human ground-truth answers."""

    def __init__(self):
        self.answer_processor = EvalAIAnswerProcessor()

    def _compute_answer_scores(self, raw_answers):
        """
        compute the accuracy (soft score) of human answers
        """
        answers = [self.answer_processor(a) for a in raw_answers]
        assert len(answers) == 10
        gt_answers = list(enumerate(answers))
        unique_answers = set(answers)
        unique_answer_scores = {}

        for unique_answer in unique_answers:
            accs = []
            for gt_answer in gt_answers:
                # Leave-one-out: score against the other 9 annotators.
                other_answers = [item for item in gt_answers if item != gt_answer]
                matching_answers = [
                    item for item in other_answers if item[1] == unique_answer
                ]
                # Full credit once at least 3 annotators agree.
                acc = min(1, float(len(matching_answers)) / 3)
                accs.append(acc)
            unique_answer_scores[unique_answer] = sum(accs) / len(accs)

        return unique_answer_scores

    def eval_pred_list(self, pred_list):
        """Mean soft accuracy over entries shaped
        ``{'pred_answer': str, 'gt_answers': [str] * 10}``."""
        pred_scores = []
        for entry in tqdm(pred_list):
            pred_answer = self.answer_processor(entry['pred_answer'])
            unique_answer_scores = self._compute_answer_scores(entry['gt_answers'])
            score = unique_answer_scores.get(pred_answer, 0.0)
            pred_scores.append(score)

        accuracy = sum(pred_scores) / len(pred_scores)
        return accuracy


class STVQAAccuracyEvaluator:
    """Exact-match accuracy after EvalAI answer normalization."""

    def __init__(self):
        self.answer_processor = EvalAIAnswerProcessor()

    def eval_pred_list(self, pred_list):
        """Fraction of entries whose normalized prediction appears among the
        normalized ground-truth answers."""
        pred_scores = []
        for entry in pred_list:
            pred_answer = self.answer_processor(entry['pred_answer'])
            gts = [self.answer_processor(a) for a in entry['gt_answers']]
            score = 1.0 if pred_answer in gts else 0.0
            pred_scores.append(score)

        accuracy = sum(pred_scores) / len(pred_scores)
        return accuracy


class STVQAANLSEvaluator:
    """Average Normalized Levenshtein Similarity (ANLS) metric."""

    def __init__(self):
        import editdistance  # install with `pip install editdistance`

        self.get_edit_distance = editdistance.eval

    def get_anls(self, s1, s2):
        """ANLS for one string pair: 1 - normalized edit distance, floored
        to 0 below the 0.5 similarity threshold."""
        s1 = s1.lower().strip()
        s2 = s2.lower().strip()
        longest = max(len(s1), len(s2))
        # Robustness fix: two empty strings are identical; the original
        # raised ZeroDivisionError here.
        if longest == 0:
            return 1.0
        iou = 1 - self.get_edit_distance(s1, s2) / longest
        anls = iou if iou >= 0.5 else 0.0
        return anls

    def eval_pred_list(self, pred_list):
        """Mean of the best per-entry ANLS across each entry's gt_answers."""
        pred_scores = []
        for entry in pred_list:
            anls = max(
                self.get_anls(entry['pred_answer'], gt)
                for gt in entry['gt_answers']
            )
            pred_scores.append(anls)

        accuracy = sum(pred_scores) / len(pred_scores)
        return accuracy


class TextCapsBleu4Evaluator:
    """BLEU-4 captioning metric via pycocoevalcap's PTB tokenizer + scorer."""

    def __init__(self):
        # The following script requires Java 1.8.0 and pycocotools installed.
        # The pycocoevalcap can be installed with pip as
        # pip install git+https://github.com/ronghanghu/coco-caption.git@python23
        # Original pycocoevalcap code is at https://github.com/tylin/coco-caption
        # but has no python3 support yet.
        try:
            from pycocoevalcap.bleu.bleu import Bleu
            from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
        except ModuleNotFoundError:
            print(
                'Please install pycocoevalcap module using '
                'pip install git+https://github.com/ronghanghu/coco-caption.git@python23'  # noqa
            )
            raise

        self.tokenizer = PTBTokenizer()
        self.scorer = Bleu(4)

    def eval_pred_list(self, pred_list):
        """BLEU-4 of predictions against their reference captions."""
        # Create reference and hypotheses captions.
        gts = {}
        res = {}
        for idx, entry in enumerate(pred_list):
            gts[idx] = [{'caption': a} for a in entry['gt_answers']]
            res[idx] = [{'caption': entry['pred_answer']}]

        gts = self.tokenizer.tokenize(gts)
        res = self.tokenizer.tokenize(res)
        score, _ = self.scorer.compute_score(gts, res)

        bleu4 = score[3]  # score is (Bleu-1, Bleu-2, Bleu-3, Bleu-4)
        return bleu4