import argparse
import json
from typing import List

from evaluate_mmmu import get_input_output_paths
from open_flamingo.eval.vqa_metric import VQAEval

# ANLS score calculation based on https://github.com/shunk031/ANLS/blob/6472e1d71e84d6cee28e3c6d2e18564bafaa312d/anls/metrics/dist.py#L1
# and https://github.com/shunk031/ANLS/blob/6472e1d71e84d6cee28e3c6d2e18564bafaa312d/anls/metrics/score.py#L6
# MIT License. Copyright (c) 2022 Shunsuke KITADA


def levenshtein_distance(s1: str, s2: str) -> int:
    """Minimum number of single-character edits needed to turn s1 into s2."""
    if len(s1) > len(s2):
        s1, s2 = s2, s1

    distances = list(range(len(s1) + 1))
    for i2, c2 in enumerate(s2):
        dists = [i2 + 1]
        for i1, c1 in enumerate(s1):
            if c1 == c2:
                dists.append(distances[i1])
            else:
                dists.append(1 + min((distances[i1], distances[i1 + 1], dists[-1])))
        distances = dists

    return distances[-1]


def normalized_levenshtein_distance(s1: str, s2: str) -> float:
    """Levenshtein distance normalized by the length of the longer string."""
    dist = levenshtein_distance(s1, s2)
    length = max(len(s1.upper()), len(s2.upper()))
    return 0.0 if length == 0 else dist / length


def similarity_function(prediction: str, gold_label: str, threshold: float) -> float:
    """ANLS similarity: 1 - normalized distance if below the threshold, else 0."""
    nl_score = normalized_levenshtein_distance(prediction, gold_label)
    return 1 - nl_score if nl_score < threshold else 0.0


def anls_score(
    prediction: str, gold_labels: List[str], threshold: float = 0.5
) -> float:
    """ANLS for one sample: best similarity over all gold labels."""
    # not case sensitive, but space sensitive
    y_pred = " ".join(prediction.strip().lower().split())

    anls_scores: List[float] = []
    for gold_label in gold_labels:
        # not case sensitive, but space sensitive
        y_true = " ".join(gold_label.strip().lower().split())

        score = similarity_function(y_pred, y_true, threshold)
        anls_scores.append(score)

    return max(anls_scores)


def merge_input_files(input_path):
    """Merge input files to a format compatible with the evaluator."""
    input_file_paths, output_file_path = get_input_output_paths(input_path, task="VQAv2")

    results = dict()
    for input_file_path in input_file_paths:
        with open(input_file_path, "r") as input_file:
            for line in input_file:
                res = json.loads(line)
                sample_id = res["sample_id"]

                # Skip possible duplicates.
                if sample_id in results:
                    continue

                res["question_id"] = sample_id
                results[sample_id] = res

    results = list(results.values())
    with open(output_file_path, "w") as output_file:
        json.dump(results, output_file)

    return output_file_path


def is_number(n: str):
    """Check if input is a number."""
    try:
        float(n)
        return True
    except ValueError:
        return False


def compute_vqa_accuracy(result_file, task):
    """Compute VQA accuracy."""
    with open(result_file) as f:
        merged_results = json.load(f)

    vqa = VQAEval(vqa=None, vqaRes=None)
    all_acc = []
    for res in merged_results:
        pred = res["answer"]
        pred = vqa.processPunctuation(pred)
        pred = vqa.processDigitArticle(pred)

        gt = res["gt_answer"]
        gt = [vqa.processPunctuation(ans) for ans in gt]
        gt = [vqa.processDigitArticle(ans) for ans in gt]

        # ChartQA uses relaxed accuracy:
        # "We consider an answer to be correct if it is within 5% of the gold answer.
        # For non-numeric answers, we still need an exact match to consider an answer to be correct."
        if task == "ChartQA":
            acc = 0.0
            assert len(gt) == 1, "expected exactly one groundtruth answer."
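            # For example, for gold answer "100" any numeric prediction in
            # [95.0, 105.0] counts as correct, so "96" scores 1.0 and "94"
            # scores 0.0; a non-numeric answer like "blue" needs an exact match.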
            gt = gt[0]

            pred = pred.rstrip("%")
            gt = gt.rstrip("%")

            if is_number(pred) and is_number(gt):
                pred = float(pred)
                gt = float(gt)
                if pred >= (gt * 0.95) and pred <= (gt * 1.05):
                    acc = 1.0
            elif pred == gt:
                acc = 1.0

            all_acc.append(acc)
        elif task in ("VQAv2", "TextVQA"):
            # VQAv2-style accuracy: a prediction matched by at least 3 of the
            # annotator answers scores 1.0; fewer matches score matches / 3.
            num_match = sum([pred == ans for ans in gt])
            acc = min(1.0, num_match / 3.0)
            all_acc.append(acc)
        elif task in ("SPDocVQA", "InfoVQA"):
            acc = anls_score(prediction=pred, gold_labels=gt, threshold=0.5)
            all_acc.append(acc)
        elif task == "AI2D":
            assert len(gt) == 1, f"Expected exactly 1 GT, got {gt}"
            acc = 1.0 if pred == gt[0] else 0.0
            all_acc.append(acc)
        else:
            raise NotImplementedError(f"unknown task {task}")

    acc_avg = sum(all_acc) / len(all_acc) * 100
    return acc_avg


def vqav2_eval(input_path):
    """Run VQAv2 evaluation."""
    result_file = merge_input_files(input_path)
    avg_acc = compute_vqa_accuracy(result_file, task="VQAv2")
    return avg_acc


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--input-path', type=str, help="Path to input file(s)")
    args = parser.parse_args()

    avg_acc = vqav2_eval(args.input_path)
    print(f"===== VQAv2 Accuracy {avg_acc:.2f}% =====")
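
# Example usage (hypothetical script name and input path; the input layout is
# whatever `get_input_output_paths` expects, assumed here to be a directory of
# per-rank .jsonl generation files):
#
#   python evaluate_vqav2.py --input-path /path/to/vqav2_generations
#
# Quick sanity checks for the ANLS metric in an interactive session:
#
#   >>> anls_score("forty two", ["forty two", "42"])
#   1.0
#   >>> anls_score("fourty two", ["forty two"])  # 1 edit / 10 chars -> 1 - 0.1
#   0.9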