Commit 3263c572 authored by lintangsutawika

Merge branch 'big-refactor' of https://github.com/EleutherAI/lm-evaluation-harness into squadv2

parents a27e8ed1 33d52483
......@@ -9,7 +9,7 @@ class DecontaminationFilter(Filter):
name = "track_decontamination"
def __init__(self, path):
def __init__(self, path) -> None:
"""
TODO: make sure this is only ever run once on the train set (should the result be cached as a class var, keyed by the value of "path"?).
......@@ -17,7 +17,7 @@ class DecontaminationFilter(Filter):
"""
self._decontam_results = None
def apply(self, reps):
def apply(self, resps, docs) -> None:
"""
Return {"no_contamination", "only_contamination"} keys for the 2 different subsets
"""
......
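The `apply` stub above only documents its intended contract. As a rough, hypothetical sketch of a filter returning those two keyed subsets, assuming a per-response contamination predicate (the predicate is an assumption, not part of this diff):

```python
# Hypothetical sketch only: partitions responses into the two keyed subsets named in
# the docstring above. `is_contaminated` is an assumed predicate, not upstream code.
def split_by_contamination(resps, is_contaminated):
    subsets = {"no_contamination": [], "only_contamination": []}
    for resp in resps:
        key = "only_contamination" if is_contaminated(resp) else "no_contamination"
        subsets[key].append(resp)
    return subsets

print(split_by_contamination(["a", "bb", "c"], lambda r: len(r) > 1))
# {'no_contamination': ['a', 'c'], 'only_contamination': ['bb']}
```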
......@@ -6,7 +6,9 @@ from lm_eval.api.filter import Filter
class RegexFilter(Filter):
""" """
def __init__(self, regex_pattern=r"#### (\-?[0-9\.\,]+)", fallback="[invalid]"):
def __init__(
self, regex_pattern: str = r"#### (\-?[0-9\.\,]+)", fallback: str = "[invalid]"
) -> None:
"""
pass a string `regex` to run `re.compile(r"regex")` on.
`fallback` defines the output returned if no matches for the regex are located.
......@@ -15,7 +17,7 @@ class RegexFilter(Filter):
self.regex = re.compile(regex_pattern)
self.fallback = fallback
def apply(self, resps):
def apply(self, resps, docs):
# here, we assume we have a list, in which each element is
# a list of model responses for some particular input/target pair.
# so we process each of these (same input/target response sets)
......@@ -41,12 +43,11 @@ class RegexFilter(Filter):
class WhitespaceFilter(Filter):
""" """
def __init__(self):
def __init__(self) -> None:
pass
def apply(self, resps):
def apply(self, resps, docs):
def filter_set(inst):
filtered_resp = []
for resp in inst:
if resp.startswith(" "):
......
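For reference, a self-contained sketch of the extraction behavior documented above (compile a regex, fall back to `fallback` when nothing matches); this mirrors the stated contract rather than the exact upstream `apply` body:

```python
import re

# Sketch of the documented RegexFilter contract: extract the first regex match per
# response, or return the fallback string when no match is found.
regex = re.compile(r"#### (\-?[0-9\.\,]+)")
fallback = "[invalid]"

def extract_one(resp: str) -> str:
    match = regex.search(resp)
    return match.group(1).strip() if match else fallback

# one doc with two model responses
resps = [["The total is #### 42", "no numeric answer given"]]
print([[extract_one(r) for r in doc_resps] for doc_resps in resps])
# [['42', '[invalid]']]
```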
......@@ -4,12 +4,12 @@ from lm_eval.api.filter import Filter
class TakeFirstFilter(Filter):
def __init__(self):
def __init__(self) -> None:
"""
Can define custom behavior here, if an individual instantiation of a Filter class should have state.
"""
def apply(self, resps):
def apply(self, resps, docs):
"""
Assuming each entry of `resps` is a list of model responses, we discard all but the first response.
"""
......@@ -17,13 +17,12 @@ class TakeFirstFilter(Filter):
class TakeKFilter(Filter):
def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
self.k = kwargs.pop("k")
super().__init__(*args, **kwargs)
def apply(self, resps):
def apply(self, resps, docs):
# check we have at least k responses per doc, else we can't take the first k
assert (
len(resps[0]) >= self.k
......@@ -32,12 +31,12 @@ class TakeKFilter(Filter):
class MajorityVoteFilter(Filter):
def __init__(self):
def __init__(self) -> None:
"""
Can define custom behavior here, if an individual instantiation of a Filter class should have state.
"""
def apply(self, resps):
def apply(self, resps, docs):
"""
Each entry of `resps` is a list of model responses.
We select the response that occurs most frequently in each entry of `resps`.
......
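A minimal sketch of the majority-vote selection described in the docstring above, assuming each entry of `resps` is a list of string responses for one document (tie-breaking may differ from the real filter):

```python
from collections import Counter

def majority_vote(resps):
    # pick the most frequent response per document
    return [Counter(doc_resps).most_common(1)[0][0] for doc_resps in resps]

print(majority_vote([["7", "8", "7"], ["yes", "no", "yes", "yes"]]))  # ['7', 'yes']
```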
......@@ -6,3 +6,5 @@ logging.basicConfig(
level=logging.INFO,
)
eval_logger = logging.getLogger("lm-eval")
SPACING = " " * 47
......@@ -76,7 +76,7 @@ class AnthropicLM(LM):
max_tokens_to_sample: int = 256,
temperature: float = 0, # defaults to 1
**kwargs, # top_p, top_k, etc.
):
) -> None:
"""Anthropic API wrapper.
:param model: str
......@@ -135,11 +135,10 @@ please install anthropic via `pip install lm-eval[anthropic]` or `pip install -e
def tok_decode(self, tokens: List[int]) -> str:
return self.tokenizer.decode(tokens)
def _loglikelihood_tokens(self, requests, disable_tqdm=False):
def _loglikelihood_tokens(self, requests, disable_tqdm: bool = False):
raise NotImplementedError("No support for logits.")
def greedy_until(self, requests) -> List[str]:
if not requests:
return []
......
......@@ -5,7 +5,7 @@ from lm_eval.api.registry import register_model
@register_model("dummy")
class DummyLM(LM):
def __init__(self):
def __init__(self) -> None:
super().__init__()
@classmethod
......
import os
import torch
import transformers
from transformers.models.auto.modeling_auto import (
......@@ -20,7 +22,7 @@ from lm_eval.api.registry import register_model
from lm_eval.utils import MultiTokenEOSCriteria, stop_sequences_criteria
from accelerate import Accelerator, find_executable_batch_size
from accelerate import Accelerator, find_executable_batch_size, DistributedType
from typing import List, Optional, Union
......@@ -67,6 +69,7 @@ class HFLM(LM):
revision: Optional[str] = "main",
subfolder: Optional[str] = None,
tokenizer: Optional[str] = None,
truncation: Optional[bool] = False,
max_length: Optional[int] = None,
device: Optional[str] = "cuda",
dtype: Optional[Union[str, torch.dtype]] = "auto",
......@@ -75,6 +78,7 @@ class HFLM(LM):
low_cpu_mem_usage: Optional[bool] = True,
trust_remote_code: Optional[bool] = False,
use_fast_tokenizer: Optional[bool] = True,
cache_dir: Optional[Union[str, os.PathLike]] = None,
# arguments used for splitting a model across GPUs naively.
# only used if `parallelize=True`.
parallelize: Optional[bool] = False,
......@@ -90,7 +94,7 @@ class HFLM(LM):
bnb_4bit_compute_dtype: Optional[Union[str, torch.dtype]] = None,
gptq: Optional[Union[bool, str]] = False,
gptq_use_triton: Optional[bool] = False,
):
) -> None:
super().__init__()
assert isinstance(device, str)
......@@ -103,17 +107,20 @@ class HFLM(LM):
if not (parallelize or accelerator.num_processes > 1):
# use user-passed device
device_list = set(
["cuda", "cpu", "mps"]
["cuda", "cpu"]
+ [f"cuda:{i}" for i in range(torch.cuda.device_count())]
+ ["mps", "mps:0"]
)
if device:
if device not in device_list:
device = int(device)
self._device = torch.device(device)
eval_logger.info(f"Using device '{device}'")
if device == "mps":
if device in ("mps", "mps:0") and "dev" not in torch.__version__:
eval_logger.info(
"MPS is still in beta and only supports float32; setting dtype to float32."
"MPS: Setting dtype to float32. To use float16 with MPS, please install a nightly build of "
"PyTorch: pip3 install --pre torch torchvision torchaudio --index-url "
"https://download.pytorch.org/whl/nightly/cpu"
)
else:
eval_logger.info("Device not specified")
......@@ -240,6 +247,8 @@ class HFLM(LM):
use_fast=use_fast_tokenizer,
)
self.truncation = truncation
self.vocab_size = self.tokenizer.vocab_size
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
......@@ -289,9 +298,16 @@ class HFLM(LM):
"Failed to place model onto specified device. This may be because the model is quantized via `bitsandbytes`. If the desired GPU is being used, this message is safe to ignore."
)
else:
self._model = accelerator.prepare_model(
self.model, evaluation_mode=True
)
assert accelerator.distributed_type in [
DistributedType.FSDP,
DistributedType.MULTI_GPU,
], "Unsupported distributed type provided. Only DDP and FSDP are supported."
if accelerator.distributed_type == DistributedType.FSDP:
self._model = accelerator.prepare(self.model)
else:
self._model = accelerator.prepare_model(
self.model, evaluation_mode=True
)
self._device = torch.device(f"cuda:{accelerator.local_process_index}")
self.accelerator = accelerator
......@@ -334,7 +350,7 @@ class HFLM(LM):
return self._DEFAULT_MAX_LENGTH
@property
def max_gen_toks(self):
def max_gen_toks(self) -> int:
return 256
@property
......@@ -353,7 +369,7 @@ class HFLM(LM):
def world_size(self):
return self._world_size
def _detect_batch_size(self, requests=None, pos=0):
def _detect_batch_size(self, requests=None, pos: int = 0):
if requests:
_, context_enc, continuation_enc = requests[pos]
max_length = len(
......@@ -419,7 +435,11 @@ class HFLM(LM):
return encoding
def tok_batch_encode(
self, strings: List[str], padding_side="left", left_truncate_len=None
self,
strings: List[str],
padding_side: str = "left",
left_truncate_len: int = None,
truncation: bool = False,
):
# encode a batch of strings. converts to tensors and pads automatically, unlike tok_encode.
old_padding_side = self.tokenizer.padding_side
......@@ -432,6 +452,7 @@ class HFLM(LM):
encoding = self.tokenizer(
strings,
truncation=truncation,
padding="longest",
return_tensors="pt",
add_special_tokens=add_special_tokens,
......@@ -595,7 +616,9 @@ class HFLM(LM):
return loglikelihoods
def _loglikelihood_tokens(self, requests, disable_tqdm=False, override_bs=None):
def _loglikelihood_tokens(
self, requests, disable_tqdm: bool = False, override_bs=None
):
# TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
res = []
......@@ -856,7 +879,9 @@ class HFLM(LM):
# encode, pad, and truncate contexts for this batch
context_enc, attn_masks = self.tok_batch_encode(
contexts, left_truncate_len=max_ctx_len
contexts,
left_truncate_len=max_ctx_len,
truncation=self.truncation,
)
context_enc = context_enc.to(self.device)
attn_masks = attn_masks.to(self.device)
......
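The hunks above thread a `truncation` flag from `__init__` through `tok_batch_encode` and left-truncate contexts to `max_ctx_len`. A rough standalone sketch of that encoding pattern with a Hugging Face tokenizer (the model name and lengths are placeholders, not values from the diff):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token_id = tokenizer.eos_token_id  # gpt2 has no pad token
tokenizer.padding_side = "left"  # pad on the left so generation continues the context

def tok_batch_encode(strings, left_truncate_len=None, truncation=False):
    encoding = tokenizer(
        strings,
        truncation=truncation,
        padding="longest",
        return_tensors="pt",
    )
    if left_truncate_len:
        # keep only the rightmost `left_truncate_len` tokens of each context
        encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:]
        encoding["attention_mask"] = encoding["attention_mask"][:, -left_truncate_len:]
    return encoding["input_ids"], encoding["attention_mask"]

ids, masks = tok_batch_encode(
    ["a short context", "a much longer context string"], left_truncate_len=4
)
print(ids.shape, masks.shape)  # torch.Size([2, 4]) torch.Size([2, 4])
```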
......@@ -69,7 +69,7 @@ class OpenaiCompletionsLM(LM):
engine: str = "text-davinci-003",
truncate: bool = False,
batch_size: int = 1,
):
) -> None:
"""
:param engine: str
......@@ -99,12 +99,12 @@ class OpenaiCompletionsLM(LM):
return self.end_of_text_token_id
@property
def max_length(self):
def max_length(self) -> int:
# Note: the OpenAI API supports up to 2049 tokens, with the first token being the first input token
return 2048
@property
def max_gen_toks(self):
def max_gen_toks(self) -> int:
return 256
@property
......@@ -152,7 +152,7 @@ class OpenaiCompletionsLM(LM):
return self._loglikelihood_tokens(new_reqs)
def _loglikelihood_tokens(
self, requests, disable_tqdm=False
self, requests, disable_tqdm: bool = False
) -> List[Tuple[float, bool]]:
res = []
......
......@@ -41,7 +41,7 @@ def textsynth_completion(**kwargs):
@register_model("textsynth")
class TextSynthLM(LM):
def __init__(self, engine, truncate=False):
def __init__(self, engine, truncate: bool = False) -> None:
"""
:param engine: str
TextSynth API engine (e.g. `gptj_6B`)
......@@ -62,12 +62,12 @@ class TextSynthLM(LM):
raise NotImplementedError()
@property
def max_length(self):
def max_length(self) -> int:
# NOTE: Turn on truncation to avoid errors on long inputs.
return 2048
@property
def max_gen_toks(self):
def max_gen_toks(self) -> int:
return 256
@property
......
import ast
from typing import Dict
from lm_eval import utils
from lm_eval.logger import eval_logger
......@@ -5,7 +8,7 @@ from lm_eval.logger import eval_logger
# Stores prompts in a dictionary indexed by 2 levels:
# prompt category name, and prompt name.
# This allows us to access prompts as `category_name:prompt_name`.
PROMPT_REGISTRY = {
PROMPT_REGISTRY: Dict[str, Dict[str, str]] = {
"qa-basic": {
"question-newline-answer": "Question: {{question}}\nAnswer:",
"q-newline-a": "Q: {{question}}\nA:",
......@@ -13,7 +16,7 @@ PROMPT_REGISTRY = {
}
def get_prompt(prompt_id: str, dataset_name=None, subset_name=None):
def get_prompt(prompt_id: str, dataset_name: str = None, subset_name: str = None):
# unpack prompt name
category_name, prompt_name = prompt_id.split(":")
if subset_name is None:
......@@ -63,6 +66,12 @@ def load_prompt_list(use_prompt: str, dataset_name=None, subset_name=None, **kwa
else:
prompts = DatasetTemplates(dataset_name=dataset_name, subset_name=subset_name)
category_name, prompt_name = use_prompt.split(":")
category_name, *prompt_name = use_prompt.split(":")
# TODO: allow multiple prompt names
# if len(prompt_name) > 1:
# prompt_list = []
# for prompt in prompt_name:
# prompt_list.append(utils.pattern_match(prompt_name, prompts.all_template_names))
# else:
prompt_list = utils.pattern_match(prompt_name, prompts.all_template_names)
return [":".join([category_name, prompt]) for prompt in prompt_list]
......@@ -5,8 +5,8 @@ Boxes should be checked iff tasks are implemented in the refactor and tested for
- [x] Glue
- [x] SuperGlue
- [ ] CoQA (Lintang)
- [ ] DROP (Lintang)
- [x] CoQA
- [x] DROP
- [x] ~~Lambada~~
- [x] Lambada (Cloze variants)
- [x] ~~Lambada (Multilingual)~~
......@@ -29,7 +29,7 @@ Boxes should be checked iff tasks are implemented in the refactor and tested for
- [x] HeadQA
- [x] MathQA
- [x] WebQs
- [ ] WSC273 (Lintang)
- [x] WSC273
- [x] Winogrande
- [x] ANLI
- [x] Hendrycks Ethics (missing some tasks/metrics, see PR 660: <https://github.com/EleutherAI/lm-evaluation-harness/pull/660> for more info)
......@@ -38,11 +38,11 @@ Boxes should be checked iff tasks are implemented in the refactor and tested for
- [x] TruthfulQA (gen)
- [ ] MuTual
- [ ] Hendrycks Math (Hailey)
- [ ] Asdiv
- [x] Asdiv
- [ ] GSM8k
- [x] Arithmetic
- [ ] MMMLU (Hailey)
- [ ] Translation (WMT) suite (Hailey)
- [x] Translation (WMT) suite
- [x] Unscramble
- [x] ~~Pile (perplexity)~~
- [x] BLiMP
......@@ -56,7 +56,7 @@ Boxes should be checked iff tasks are implemented in the refactor and tested for
- [x] XWinograd
- [x] PAWS-X
- [x] XNLI
- [ ] MGSM (Lintang)
- [x] MGSM
- [ ] SCROLLS
- [x] Babi
......
import os
import yaml
from typing import List, Union
from typing import List, Union, Dict
from lm_eval import utils
from lm_eval import prompts
......@@ -15,7 +15,7 @@ from lm_eval.api.registry import (
)
def register_configurable_task(config):
def register_configurable_task(config: Dict[str, str]) -> int:
SubClass = type(
config["task"] + "ConfigurableTask",
(ConfigurableTask,),
......@@ -38,7 +38,7 @@ def register_configurable_task(config):
return 0
def check_prompt_config(config):
def check_prompt_config(config: Dict[str, str]) -> List[Dict[str, str]]:
all_configs = []
if "use_prompt" in config:
prompt_list = prompts.load_prompt_list(
......@@ -69,14 +69,14 @@ def check_prompt_config(config):
return all_configs
def get_task_name_from_config(task_config):
def get_task_name_from_config(task_config: Dict[str, str]) -> str:
if "dataset_name" in task_config:
return "{dataset_path}_{dataset_name}".format(**task_config)
else:
return "{dataset_path}".format(**task_config)
def include_task_folder(task_dir):
def include_task_folder(task_dir: str) -> None:
"""
Calling this function registers every task config found under `task_dir`.
"""
......@@ -128,7 +128,7 @@ def get_task_name_from_object(task_object):
# TODO: pass num_fewshot and other cmdline overrides in a better way
def get_task_dict(task_name_list: List[Union[str, dict, Task]], **kwargs):
def get_task_dict(task_name_list: List[Union[str, Dict, Task]], **kwargs):
config = {**kwargs}
......@@ -136,6 +136,9 @@ def get_task_dict(task_name_list: List[Union[str, dict, Task]], **kwargs):
task_name_from_config_dict = {}
task_name_from_object_dict = {}
if type(task_name_list) != list:
task_name_list = [task_name_list]
for task_element in task_name_list:
if isinstance(task_element, str):
......@@ -143,12 +146,20 @@ def get_task_dict(task_name_list: List[Union[str, dict, Task]], **kwargs):
group_name = task_element
for task_name in GROUP_REGISTRY[task_element]:
if task_name not in task_name_from_registry_dict:
task_obj = get_task_dict(task_name)
if task_name in task_obj.keys():
task_dict = {
task_name: (group_name, task_obj[task_name]),
}
else:
task_dict = {
task_name: (group_name, None),
**task_obj,
}
task_name_from_registry_dict = {
**task_name_from_registry_dict,
task_name: (
group_name,
get_task(task_name=task_name, config=config),
),
**task_dict,
}
else:
task_name = task_element
......
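A simplified, standalone illustration of the two behaviors added above: a bare task name is wrapped into a list, and a group name expands into `(group, task)` entries. The registries below are stand-ins for lm-eval's real ones:

```python
# Stand-in registries, for illustration only.
GROUP_REGISTRY = {"my_group": ["task_a", "task_b"]}
TASK_REGISTRY = {"task_a": "TaskAObject", "task_b": "TaskBObject"}

def get_task_dict_sketch(task_name_list):
    if not isinstance(task_name_list, list):
        task_name_list = [task_name_list]  # accept a single name as well as a list
    task_dict = {}
    for task_element in task_name_list:
        if task_element in GROUP_REGISTRY:
            # expand the group into (group_name, task_object) pairs
            for task_name in GROUP_REGISTRY[task_element]:
                task_dict[task_name] = (task_element, TASK_REGISTRY[task_name])
        else:
            task_dict[task_element] = TASK_REGISTRY[task_element]
    return task_dict

print(get_task_dict_sketch("my_group"))
# {'task_a': ('my_group', 'TaskAObject'), 'task_b': ('my_group', 'TaskBObject')}
```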
task: asdiv
dataset_path: EleutherAI/asdiv
output_type: loglikelihood
validation_split: validation
doc_to_text: "{{body}}\nQuestion:{{question}}\nAnswer:"
doc_to_target: "{{answer.split(' (')[0]}}"
should_decontaminate: true
doc_to_decontamination_query: "{{body}} {{question}}"
metric_list:
- metric: acc
aggregation: mean
higher_is_better: true
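The `doc_to_text` and `doc_to_target` fields above are Jinja templates rendered against each document. A hedged illustration with an invented ASDiv-style document (the sample values are not taken from the dataset):

```python
from jinja2 import Template

# Invented sample document, for illustration only.
doc = {
    "body": "Seven red apples and two green apples are in the basket.",
    "question": "How many apples are in the basket?",
    "answer": "9 (apples)",
}

doc_to_text = Template("{{body}}\nQuestion:{{question}}\nAnswer:")
doc_to_target = Template("{{answer.split(' (')[0]}}")

print(doc_to_text.render(**doc))
print(doc_to_target.render(**doc))  # 9
```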
# CoQA
### Paper
Title: `CoQA: A Conversational Question Answering Challenge`
Abstract: https://arxiv.org/pdf/1808.07042.pdf
CoQA is a large-scale dataset for building Conversational Question Answering
systems. The goal of the CoQA challenge is to measure the ability of machines to
understand a text passage and answer a series of interconnected questions that
appear in a conversation.
Homepage: https://stanfordnlp.github.io/coqa/
### Citation
```
BibTeX-formatted citation goes here
```
### Groups and Tasks
#### Groups
* Not part of a group yet
#### Tasks
* `coqa`
### Checklist
For adding novel benchmarks/datasets to the library:
* [ ] Is the task an existing benchmark in the literature?
* [ ] Have you referenced the original paper that introduced the task?
* [ ] If yes, does the original paper provide a reference implementation? If so, have you checked against the reference implementation and documented how to run such a test?
If other tasks on this dataset are already supported:
* [ ] Is the "Main" variant of this task clearly denoted?
* [ ] Have you provided a short sentence in a README on what each new variant adds / evaluates?
* [ ] Have you noted which, if any, published evaluation setups are matched by this variant?
task: coqa
dataset_path: EleutherAI/coqa
output_type: greedy_until
training_split: train
validation_split: validation
doc_to_text: !function utils.doc_to_text
doc_to_target: !function utils.doc_to_target
process_results: !function utils.process_results
should_decontaminate: true
doc_to_decontamination_query: "{{story}} {{question.input_text|join('\n')}}"
generation_kwargs:
until:
- "\nQ:"
metric_list:
- metric: em
aggregation: mean
higher_is_better: true
- metric: f1
aggregation: mean
higher_is_better: true
from itertools import zip_longest
import transformers.data.metrics.squad_metrics as squad_metrics
def doc_to_text(doc):
# Given a passage p, the conversation history {q1, a1, ..., q_{i-1}, a_{i-1}}
# and a question q_i, the task is to predict the answer a_i
doc_text = doc["story"] + "\n\n"
for (q, a) in zip_longest(
doc["questions"]["input_text"], doc["answers"]["input_text"][:-1]
): # omit target answer ai
question = f"Q: {q}\n\n"
answer = f"A: {a}\n\n" if a is not None else "A:"
doc_text += question + answer
return doc_text
def doc_to_target(doc):
turn_id = len(doc["questions"]["input_text"])
# Returns unique answers and valid alternatives (Some questions in CoQA have multiple valid answers).
answers = []
answer_forturn = doc["answers"]["input_text"][turn_id - 1]
answers.append(answer_forturn)
additional_answers = doc.get("additional_answers")
if additional_answers:
for key in additional_answers:
additional_answer_for_turn = additional_answers[key]["input_text"][
turn_id - 1
]
if additional_answer_for_turn.lower() not in map(str.lower, answers):
answers.append(additional_answer_for_turn)
return answers
def em(gold_list, pred):
# tests for exact match on the normalised answer (compute_exact)
em_sum = 0.0
if len(gold_list) > 1:
for i in range(len(gold_list)):
gold_answers = gold_list[0:i] + gold_list[i + 1 :]
# the prediction is compared against the remaining golds and the maximum is taken
em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_answers)
else:
em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_list)
return em_sum / max(1, len(gold_list))
def compute_scores(gold_list, pred):
# tests for exact match on the normalised answer (compute_exact)
# and for token overlap (compute_f1)
f1_sum = 0.0
em_sum = 0.0
if len(gold_list) > 1:
for i in range(len(gold_list)):
gold_answers = gold_list[0:i] + gold_list[i + 1 :]
# the prediction is compared against the remaining golds and the maximum is taken
em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_answers)
f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_answers)
else:
em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_list)
f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_list)
return {
"em": em_sum / max(1, len(gold_list)),
"f1": f1_sum / max(1, len(gold_list)),
}
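Example usage of `compute_scores`, assuming the function and imports above are in scope; the gold list is invented. With multiple golds, each gold is scored leave-one-out against the remaining answers and the per-gold maxima are averaged:

```python
# Invented example: all three golds normalise to "red house" under SQuAD normalisation
# (lowercasing, article and punctuation removal), so both metrics come out to 1.0.
gold_list = ["the red house", "a red house", "red house"]
print(compute_scores(gold_list, "red house"))  # {'em': 1.0, 'f1': 1.0}
```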
def process_results(doc, results):
gold_list = doc_to_target(doc)
pred = results[0].strip().split("\n")[0]
scores = compute_scores(gold_list, pred)
return scores
# DROP
### Paper
Title: `DROP: A Reading Comprehension Benchmark Requiring Discrete Reasoning Over Paragraphs`
Abstract: https://aclanthology.org/attachments/N19-1246.Supplementary.pdf
DROP is a QA dataset which tests comprehensive understanding of paragraphs. In
this crowdsourced, adversarially-created, 96k question-answering benchmark, a
system must resolve multiple references in a question, map them onto a paragraph,
and perform discrete operations over them (such as addition, counting, or sorting).
Homepage: https://allenai.org/data/drop
Acknowledgement: This implementation is based on the official evaluation for `DROP`:
https://github.com/allenai/allennlp-reading-comprehension/blob/master/allennlp_rc/eval/drop_eval.py
### Citation
```
@misc{dua2019drop,
title={DROP: A Reading Comprehension Benchmark Requiring Discrete Reasoning Over Paragraphs},
author={Dheeru Dua and Yizhong Wang and Pradeep Dasigi and Gabriel Stanovsky and Sameer Singh and Matt Gardner},
year={2019},
eprint={1903.00161},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
```
### Groups and Tasks
#### Groups
* Not part of a group yet.
#### Tasks
* `drop`
### Checklist
For adding novel benchmarks/datasets to the library:
* [ ] Is the task an existing benchmark in the literature?
* [ ] Have you referenced the original paper that introduced the task?
* [ ] If yes, does the original paper provide a reference implementation? If so, have you checked against the reference implementation and documented how to run such a test?
If other tasks on this dataset are already supported:
* [ ] Is the "Main" variant of this task clearly denoted?
* [ ] Have you provided a short sentence in a README on what each new variant adds / evaluates?
* [ ] Have you noted which, if any, published evaluation setups are matched by this variant?
task: drop
dataset_path: EleutherAI/drop
output_type: greedy_until
training_split: train
validation_split: validation
process_docs: !function utils.process_docs
doc_to_text: "{{passage}} {{question}}"
doc_to_target: "{{ answer|join(',')}}"
target_delimiter: ""
process_results: !function utils.process_results
should_decontaminate: true
doc_to_decontamination_query: "{{passage}} {{question}}"
generation_kwargs:
until:
- "."
metric_list:
- metric: em
aggregation: mean
higher_is_better: true
- metric: f1
aggregation: mean
higher_is_better: true
import re
import string
import numpy as np
from scipy.optimize import linear_sum_assignment
_ARTICLES = re.compile(r"\b(a|an|the)\b", re.UNICODE)
def process_docs(dataset):
def _process(doc):
return {
"id": doc["query_id"],
"passage": doc["passage"],
"question": doc["question"],
"answers": get_answers(doc),
}
return dataset.map(_process)
def get_answers(doc):
def _flatten_validated_answers(validated_answers):
"""Flattens a dict of lists of validated answers.
{"number": ['1', '8'], ...}
-> [{"number": ['1'], ...}, {"number": ['8'], ...}]
"""
valid_answers = []
for i in range(len(validated_answers["number"])):
valid_answers.append(
{
"number": validated_answers["number"][i],
"date": validated_answers["date"][i],
"spans": validated_answers["spans"][i],
}
)
return valid_answers
answers = []
answers_set = set()
candidates = [doc["answer"]] + _flatten_validated_answers(doc["validated_answers"])
for candidate in candidates:
answer = parse_answer(candidate)
if answer in answers_set:
continue
answers_set.add(answer)
answers.append(answer)
return answers
def parse_answer(answer):
# NOTE: Everything is returned as a tuple for uniformity and hashability.
if answer["number"] != "":
return (str(answer["number"]),)
if answer["spans"] != []:
return tuple(answer["spans"])
return (
" ".join(
[answer["date"]["day"], answer["date"]["month"], answer["date"]["year"]]
).strip(),
)
def process_results(doc, results):
preds, golds = results, doc["answers"]
max_em = 0
max_f1 = 0
for gold_answer in golds:
exact_match, f1_score = get_metrics(preds, gold_answer)
if gold_answer[0].strip():
max_em = max(max_em, exact_match)
max_f1 = max(max_f1, f1_score)
return {"em": max_em, "f1": max_f1}
def get_metrics(predicted, gold):
"""
Takes a predicted answer and a gold answer (that are both either a string or a list of
strings), and returns exact match and the DROP F1 metric for the prediction. If you are
writing a script for evaluating objects in memory (say, the output of predictions during
validation, or while training), this is the function you want to call, after using
:func:`answer_json_to_strings` when reading the gold answer from the released data file.
"""
predicted_bags = _answer_to_bags(predicted)
gold_bags = _answer_to_bags(gold)
if set(predicted_bags[0]) == set(gold_bags[0]) and len(predicted_bags[0]) == len(
gold_bags[0]
):
exact_match = 1.0
else:
exact_match = 0.0
f1_per_bag = _align_bags(predicted_bags[1], gold_bags[1])
f1 = np.mean(f1_per_bag)
f1 = round(f1, 2)
return exact_match, f1
def _answer_to_bags(answer):
if isinstance(answer, (list, tuple)):
raw_spans = answer
else:
raw_spans = [answer]
normalized_spans = []
token_bags = []
for raw_span in raw_spans:
normalized_span = _normalize(raw_span)
normalized_spans.append(normalized_span)
token_bags.append(set(normalized_span.split()))
return normalized_spans, token_bags
def _align_bags(predicted, gold):
"""
Takes gold and predicted answer sets and first finds the optimal 1-1 alignment
between them and gets maximum metric values over all the answers.
"""
scores = np.zeros([len(gold), len(predicted)])
for gold_index, gold_item in enumerate(gold):
for pred_index, pred_item in enumerate(predicted):
if _match_numbers_if_present(gold_item, pred_item):
scores[gold_index, pred_index] = _compute_f1(pred_item, gold_item)
row_ind, col_ind = linear_sum_assignment(-scores)
max_scores = np.zeros([max(len(gold), len(predicted))])
for row, column in zip(row_ind, col_ind):
max_scores[row] = max(max_scores[row], scores[row, column])
return max_scores
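A standalone demo of the alignment step used in `_align_bags`: bag-level scores are maximised with the Hungarian algorithm (the matrix is negated because `linear_sum_assignment` minimises). The score matrix below is made up:

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

# rows: gold bags, cols: predicted bags; entries play the role of per-pair F1 scores
scores = np.array([[0.2, 0.9], [0.8, 0.1]])
row_ind, col_ind = linear_sum_assignment(-scores)
print(list(zip(row_ind.tolist(), col_ind.tolist())))  # [(0, 1), (1, 0)]
print(scores[row_ind, col_ind].sum())  # ~1.7, the best total over all 1-1 pairings
```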
def _compute_f1(predicted_bag, gold_bag):
intersection = len(gold_bag.intersection(predicted_bag))
if not predicted_bag:
precision = 1.0
else:
precision = intersection / float(len(predicted_bag))
if not gold_bag:
recall = 1.0
else:
recall = intersection / float(len(gold_bag))
f1 = (
(2 * precision * recall) / (precision + recall)
if not (precision == 0.0 and recall == 0.0)
else 0.0
)
return f1
def _match_numbers_if_present(gold_bag, predicted_bag):
gold_numbers = set()
predicted_numbers = set()
for word in gold_bag:
if _is_number(word):
gold_numbers.add(word)
for word in predicted_bag:
if _is_number(word):
predicted_numbers.add(word)
if (not gold_numbers) or gold_numbers.intersection(predicted_numbers):
return True
return False
def _is_number(text):
try:
float(text)
return True
except ValueError:
return False
def _remove_articles(text):
return _ARTICLES.sub(" ", text)
def _white_space_fix(text):
return " ".join(text.split())
def _remove_punc(text):
exclude = set(string.punctuation)
if not _is_number(text):
return "".join(ch for ch in text if ch not in exclude)
else:
return text
def _fix_number(text):
return str(float(text)) if _is_number(text) else text
def _tokenize(text):
return re.split(" |-", text)
def _normalize(answer):
tokens = [
_white_space_fix(_remove_articles(_fix_number(_remove_punc(token.lower()))))
for token in _tokenize(answer)
]
tokens = [token for token in tokens if token.strip()]
normalized = " ".join(tokens).strip()
return normalized
def doc_to_text(doc):
def doc_to_text(doc) -> str:
return "{}\nQuestion: {} True, False or Neither?\nAnswer:".format(
doc["premise"],
doc["hypothesis"].strip()
......