"vscode:/vscode.git/clone" did not exist on "1612d8f34d505cd550f2bbab98e2274a8638935c"
Commit c21240c0 authored by lintangsutawika's avatar lintangsutawika
Browse files

Merge branch 'big-refactor' of https://github.com/EleutherAI/lm-evaluation-harness into alt_worlds

parents bbd6ab3a afda6551
......@@ -59,6 +59,8 @@ my_model = initialize_my_model() # create your model (could be running finetunin
...
lm_obj = Your_LM(model=my_model, batch_size=16) # instantiate an LM subclass that takes your initialized model and can run `Your_LM.loglikelihood()`, `Your_LM.loglikelihood_rolling()`, `Your_LM.generate_until()`
lm_eval.tasks.initialize_tasks() # register all tasks from the `lm_eval/tasks` subdirectory. Alternatively, can call `lm_eval.tasks.include_path("path/to/my/custom/task/configs")` to only register a set of tasks in a separate directory.
results = lm_eval.simple_evaluate( # call simple_evaluate
model=lm_obj,
tasks=["taskname1", "taskname2"],
......@@ -85,7 +87,7 @@ my_model = initialize_my_model() # create your model (could be running finetunin
...
lm_obj = Your_LM(model=my_model, batch_size=16) # instantiate an LM subclass that takes your initialized model and can run `Your_LM.loglikelihood()`, `Your_LM.loglikelihood_rolling()`, `Your_LM.generate_until()`
lm_eval.tasks.initialize_tasks() # register all tasks from the `lm_eval/tasks` subdirectory. Alternatively, can call `lm_eval.tasks.include_path("path/to/my/custom/task/configs")` to only register a set of tasks in a separate directory.
def evaluate(
lm=lm_obj,
......
import os
import re
import sys
import json
import fnmatch
import argparse
import logging
from pathlib import Path
import argparse
import numpy as np
from lm_eval import evaluator, utils
from lm_eval.api.registry import ALL_TASKS
from lm_eval.logger import eval_logger, SPACING
from lm_eval.tasks import include_path
from pathlib import Path
from typing import Union
from lm_eval import evaluator, utils
from lm_eval.tasks import initialize_tasks, include_path
from lm_eval.api.registry import ALL_TASKS
def _handle_non_serializable(o):
if isinstance(o, np.int64) or isinstance(o, np.int32):
......@@ -25,11 +25,11 @@ def _handle_non_serializable(o):
def parse_eval_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
parser.add_argument("--model", required=True, help="Name of model e.g. `hf`")
parser.add_argument("--model", default="hf", help="Name of model e.g. `hf`")
parser.add_argument(
"--tasks",
default=None,
help="Available Tasks:\n - {}".format("\n - ".join(sorted(ALL_TASKS))),
help="To get full list of tasks, use the command lm-eval --tasks list",
)
parser.add_argument(
"--model_args",
......@@ -119,9 +119,13 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
# we allow for args to be passed externally, else we parse them ourselves
args = parse_eval_args()
eval_logger = utils.eval_logger
eval_logger.setLevel(getattr(logging, f"{args.verbosity}"))
eval_logger.info(f"Verbosity set to {args.verbosity}")
os.environ["TOKENIZERS_PARALLELISM"] = "false"
initialize_tasks(args.verbosity)
if args.limit:
eval_logger.warning(
" --limit SHOULD ONLY BE USED FOR TESTING."
......@@ -133,6 +137,11 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
if args.tasks is None:
task_names = ALL_TASKS
elif args.tasks == "list":
eval_logger.info(
"Available Tasks:\n - {}".format(f"\n - ".join(sorted(ALL_TASKS)))
)
sys.exit()
else:
if os.path.isdir(args.tasks):
import glob
......@@ -149,16 +158,20 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
if os.path.isfile(task):
config = utils.load_yaml_config(task)
task_names.append(config)
task_missing = [task for task in tasks_list if task not in task_names]
task_missing = [
task
for task in tasks_list
if task not in task_names and "*" not in task
] # we don't want errors if a wildcard ("*") task name was used
if task_missing:
missing = ", ".join(task_missing)
eval_logger.error(
f"Tasks were not found: {missing}\n"
f"{SPACING}Try `lm-eval -h` for list of available tasks",
f"{utils.SPACING}Try `lm-eval --tasks list` for list of available tasks",
)
raise ValueError(
f"Tasks {missing} were not found. Try `lm-eval -h` for list of available tasks."
f"Tasks {missing} were not found. Try `lm-eval --tasks list` for list of available tasks."
)
if args.output_path:
......
......@@ -9,6 +9,9 @@ import evaluate
from lm_eval.api.registry import register_metric, register_aggregation
import logging
eval_logger = logging.getLogger("lm-eval")
# Register Aggregations First
@register_aggregation("mean")
......
......@@ -10,7 +10,10 @@ import hashlib
from tqdm import tqdm
from lm_eval import utils
from lm_eval.logger import eval_logger
import logging
eval_logger = logging.getLogger("lm-eval")
T = TypeVar("T", bound="LM")
......
import os
import evaluate
from lm_eval.api.model import LM
from lm_eval.logger import eval_logger
import logging
eval_logger = logging.getLogger("lm-eval")
MODEL_REGISTRY = {}
......
......@@ -4,6 +4,7 @@ from dataclasses import dataclass, field, asdict
import re
import ast
import yaml
import logging
import evaluate
import random
import itertools
......@@ -21,7 +22,6 @@ from lm_eval.api import samplers
from lm_eval.api.instance import Instance
from lm_eval.api.filter import FilterEnsemble
from lm_eval.logger import eval_logger
from lm_eval.prompts import get_prompt
from lm_eval.filters import build_filter_ensemble
from lm_eval.api.metrics import (
......@@ -48,6 +48,9 @@ ALL_OUTPUT_TYPES = [
]
eval_logger = logging.getLogger("lm-eval")
@dataclass
class TaskConfig(dict):
# task naming/registry
......@@ -91,7 +94,7 @@ class TaskConfig(dict):
metadata: str = None # by default, not used in the code. allows for users to pass arbitrary info to tasks
def __post_init__(self) -> None:
if "." in self.dataset_path:
if self.dataset_path and ("." in self.dataset_path):
import inspect
from importlib import import_module
......@@ -204,19 +207,9 @@ class Task(abc.ABC):
self._fewshot_docs = None
self._instances = None
self._config = TaskConfig(**config) if config else TaskConfig()
self._config = TaskConfig({**config}) if config else TaskConfig()
if not hasattr(self, "_filters"):
self._filters = []
for name, components in self._config.get(
"filters", [["none", [["take_first", None]]]]
):
filter_pipeline = build_filter_ensemble(name, components)
self._filters.append(filter_pipeline)
self.sampler = samplers.Sampler(
list(self.fewshot_docs()), self, rnd=random.Random(1234)
)
self._filters = [build_filter_ensemble("none", [["take_first", None]])]
def download(self, data_dir=None, cache_dir=None, download_mode=None) -> None:
"""Downloads and returns the task dataset.
......@@ -357,9 +350,7 @@ class Task(abc.ABC):
False
), f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!"
eval_logger.info(
f"Building contexts for task '{self.config.task}' on rank {rank}..."
)
eval_logger.info(f"Building contexts for task on rank {rank}...")
instances = []
for doc_id, doc in utils.create_iterator(
......@@ -449,7 +440,13 @@ class Task(abc.ABC):
return len(re.split(r"\s+", doc))
@utils.positional_deprecated
def fewshot_context(self, doc, num_fewshot):
def fewshot_context(
self,
doc,
num_fewshot,
rnd=random.Random(1234),
description=None,
):
"""Returns a fewshot context string that is made up of a prepended description
(if provided), the `num_fewshot` number of examples, and an appended prompt example.
......@@ -457,34 +454,56 @@ class Task(abc.ABC):
The document as returned from training_docs, validation_docs, or test_docs.
:param num_fewshot: int
The number of fewshot examples to provide in the returned context string.
:param rnd: random.Random
The pseudo-random number generator used to randomly sample examples.
WARNING: This is currently a required arg although it's optionalized with a default `None`.
:param description: str
The task's description that will be prepended to the fewshot examples.
:returns: str
The fewshot context.
"""
assert (
rnd is not None
), "A `random.Random` generator argument must be provided to `rnd`"
description = description if description else ""
if num_fewshot == 0:
# always prepend the (possibly empty) task description
labeled_examples = self.config.description
labeled_examples = ""
else:
labeled_examples = self.config.description + self.sampler.get_context(
doc, num_fewshot
# for sets with no training docs, draw from other set *but ensure no overlap with current doc*
if self.has_training_docs():
fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
else:
if self._fewshot_docs is None:
self._fewshot_docs = list(
self.validation_docs()
if self.has_validation_docs()
else self.test_docs()
)
fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)
# get rid of the doc that's the one we're evaluating, if it's in the fewshot
fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]
labeled_examples = (
"\n\n".join(
[
self.doc_to_text(doc) + self.doc_to_target(doc)
for doc in fewshotex
]
)
+ "\n\n"
)
example = self.doc_to_text(doc)
if type(example) == str:
return labeled_examples + example
elif type(example) == list:
return [labeled_examples + ex for ex in example]
elif type(example) == int:
if self.config.doc_to_choice is not None:
choices = self.doc_to_choice(doc)
return labeled_examples + choices[example]
else:
return labeled_examples + str(example)
return description + labeled_examples + example
def apply_filters(self):
if hasattr(self, "_filters"):
for f in self._filters:
f.apply(self._instances)
f.apply(self._instances, None)
else:
eval_logger.warning("No filter defined, passing through instances")
return self._instances
......@@ -764,6 +783,39 @@ class ConfigurableTask(Task):
)
return super().fewshot_docs()
@utils.positional_deprecated
def fewshot_context(self, doc, num_fewshot):
"""Returns a fewshot context string that is made up of a prepended description
(if provided), the `num_fewshot` number of examples, and an appended prompt example.
:param doc: str
The document as returned from training_docs, validation_docs, or test_docs.
:param num_fewshot: int
The number of fewshot examples to provide in the returned context string.
:returns: str
The fewshot context.
"""
if num_fewshot == 0:
# always prepend the (possibly empty) task description
labeled_examples = self.config.description
else:
labeled_examples = self.config.description + self.sampler.get_context(
doc, num_fewshot
)
example = self.doc_to_text(doc)
if type(example) == str:
return labeled_examples + example
elif type(example) == list:
return [labeled_examples + ex for ex in example]
elif type(example) == int:
if self.config.doc_to_choice is not None:
choices = self.doc_to_choice(doc)
return labeled_examples + choices[example]
else:
return labeled_examples + str(example)
def apply_filters(self):
if hasattr(self, "_filters"):
for f in self._filters:
......
......@@ -20,10 +20,9 @@ from lm_eval.utils import (
make_table,
create_iterator,
get_git_commit_hash,
eval_logger,
)
from lm_eval.logger import eval_logger
@positional_deprecated
def simple_evaluate(
......@@ -226,6 +225,7 @@ def evaluate(
versions[group_name] = "N/A"
else:
group_name = None
task_hierarchy[task_name] = []
if task is None:
......@@ -237,8 +237,10 @@ def evaluate(
if "task_alias" in configs[task_name]:
task_group_alias[task_name] = configs[task_name]["task_alias"]
if ("group_alias" in configs[task_name]) and (
group_name not in task_group_alias
if (
("group_alias" in configs[task_name])
and (group_name not in task_group_alias)
and (group_name is not None)
):
task_group_alias[group_name] = configs[task_name]["group_alias"]
......@@ -253,7 +255,7 @@ def evaluate(
task.build_all_requests(limit=limit, rank=lm.rank, world_size=lm.world_size)
eval_logger.info(
eval_logger.debug(
f"Task: {task_name}; number of requests on this rank: {len(task.instances)}"
)
......@@ -268,12 +270,9 @@ def evaluate(
eval_logger.info(f"Request: {str(inst)}")
# aggregate Instances by LM method requested to get output.
reqtype = (
"loglikelihood"
if task.OUTPUT_TYPE == "multiple_choice"
else task.OUTPUT_TYPE
) # TODO: this is hacky, fix in task.py
requests[reqtype].extend(task.instances)
for instance in task.instances:
reqtype = instance.request_type
requests[reqtype].append(instance)
if lm.world_size > 1:
instances_rnk = torch.tensor(len(task._instances), device=lm.device)
......
import logging
logging.basicConfig(
format="%(asctime)s,%(msecs)03d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
datefmt="%Y-%m-%d:%H:%M:%S",
level=logging.INFO,
)
eval_logger = logging.getLogger("lm-eval")
SPACING = " " * 47
......@@ -2,9 +2,11 @@ from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
from tqdm import tqdm
import time
from lm_eval.logger import eval_logger
from lm_eval import utils
from typing import List, Any, Tuple
eval_logger = utils.eval_logger
def anthropic_completion(
client, #: anthropic.Anthropic,
......
......@@ -16,7 +16,6 @@ from pathlib import Path
import torch.nn.functional as F
from lm_eval import utils
from lm_eval.logger import eval_logger
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
......@@ -25,6 +24,8 @@ from lm_eval.utils import MultiTokenEOSCriteria, stop_sequences_criteria
from accelerate import Accelerator, find_executable_batch_size, DistributedType
from typing import List, Optional, Union
eval_logger = utils.eval_logger
def _get_accelerate_args(
device_map_option: Optional[str] = "auto",
......
......@@ -3,7 +3,7 @@ import ast
from typing import Dict
from lm_eval import utils
from lm_eval.logger import eval_logger
from lm_eval.utils import eval_logger
# Prompt library.
# Stores prompts in a dictionary indexed by 2 levels:
......
......@@ -15,7 +15,18 @@ from lm_eval.api.registry import (
import logging
eval_logger = logging.getLogger("lm-eval")
# import python tasks
from .squadv2.task import SQuAD2
from .scrolls.task import (
QuALITY,
NarrativeQA,
ContractNLI,
GovReport,
SummScreenFD,
QMSum,
)
eval_logger = utils.eval_logger
def register_configurable_task(config: Dict[str, str]) -> int:
......@@ -141,8 +152,11 @@ def include_task_folder(task_dir: str, register_task: bool = True) -> None:
else:
if type(config["task"]) == list:
register_configurable_group(config, yaml_path)
# Log this silently and show it only when
# the user defines the appropriate verbosity.
except ModuleNotFoundError as e:
eval_logger.warning(
eval_logger.debug(
f"{yaml_path}: {e}. Config will not be added to registry."
)
except Exception as error:
......@@ -165,8 +179,12 @@ def include_path(task_dir):
return 0
task_dir = os.path.dirname(os.path.abspath(__file__)) + "/"
include_path(task_dir)
def initialize_tasks(verbosity="INFO"):
eval_logger.setLevel(getattr(logging, f"{verbosity}"))
task_dir = os.path.dirname(os.path.abspath(__file__)) + "/"
include_path(task_dir)
def get_task(task_name, config):
......
group: bigbench
dataset_path: bigbench # will switch to `hails/bigbench` when all tasks are pushed
group: bigbench_generate_until
dataset_path: hails/bigbench
output_type: generate_until
dataset_kwargs:
# num_shots: 0 # TODO: num of shots for `bigbench` HF dataset should be controlled through this, not through the typical methods
......
group: bigbench
dataset_path: bigbench # will switch to `hails/bigbench` when all tasks are pushed
group: bigbench_multiple_choice
dataset_path: hails/bigbench
dataset_kwargs:
# num_shots: 0 # TODO: num of shots for `bigbench` HF dataset should be controlled through this, not through the typical methods
# subtask_name: null
......
import datasets
import re
import signal
from lm_eval.logger import eval_logger
from lm_eval.utils import eval_logger
from typing import Optional, List, Dict
try:
......
......@@ -3,7 +3,7 @@ import json
import requests
import numpy as np
from lm_eval.logger import eval_logger
from lm_eval.utils import eval_logger
def toxicity_perspective_api(references, predictions, **kwargs):
......
"""
SCROLLS: Standardized CompaRison Over Long Language Sequences
https://arxiv.org/abs/2201.03533
SCROLLS is a suite of datasets that require synthesizing information over long texts.
The benchmark includes seven natural language tasks across multiple domains,
including summarization, question answering, and natural language inference.
Homepage: https://www.scrolls-benchmark.com/
Since SCROLLS tasks are generally longer than the maximum sequence length of many models,
it is possible to create "subset" tasks that contain only those samples whose tokenized length
is less than some pre-defined limit. For example, to create a subset of "Qasper" that would
be suitable for a model using the GPTNeoX tokenizer and a 4K maximium sequence length:
```
class QasperGPTNeoX4K(Qasper):
PRUNE_TOKENIZERS = ["EleutherAI/pythia-410m-deduped"]
PRUNE_MAX_TOKENS = 4096
PRUNE_NUM_PROC = _num_cpu_cores() # optional, to speed up pruning of large datasets like NarrativeQA
```
`PRUNE_TOKENIZERS` can contain more than one tokenizer; this will include only samples that are
less than `PRUNE_MAX_TOKENS` for ALL of the tokenizers. This can be useful to comparing models
that use different tokenizers but the same maximum sequence length.
Once the subset task class has been defined in this file, it can be used by adding the class
to `lm_eval/tasks/__init__.py`.
NOTE: GovReport may need `max_gen_toks` set larger for causal models.
"""
group: scrolls
task:
- scrolls_qasper
- scrolls_quality
- scrolls_narrativeqa
- scrolls_contractnli
- scrolls_govreport
- scrolls_summscreenfd
- scrolls_qmsum
import re
import numpy as np
import transformers.data.metrics.squad_metrics as squad_metrics
from abc import abstractmethod
from datasets import load_metric
from transformers import AutoTokenizer
from functools import reduce
from lm_eval.api.task import Task
from lm_eval.api.metrics import mean
from lm_eval.api.instance import Instance
from lm_eval.api.registry import register_task
_CITATION = """
@inproceedings{shaham-etal-2022-scrolls,
title = "{SCROLLS}: Standardized {C}ompa{R}ison Over Long Language Sequences",
author = "Shaham, Uri and
Segal, Elad and
Ivgi, Maor and
Efrat, Avia and
Yoran, Ori and
Haviv, Adi and
Gupta, Ankit and
Xiong, Wenhan and
Geva, Mor and
Berant, Jonathan and
Levy, Omer",
booktitle = "Proceedings of the 2022 Conference on Empirical Methods in Natural Language Processing",
month = dec,
year = "2022",
address = "Abu Dhabi, United Arab Emirates",
publisher = "Association for Computational Linguistics",
url = "https://aclanthology.org/2022.emnlp-main.823",
pages = "12007--12021"
}
"""
# SCROLLS is formualted as a sequence-to-sequence task.
# To allow for evaluation of causal models, we'll
# reformualte these with appropriate prompts
def _download_metric():
import os
import shutil
from huggingface_hub import hf_hub_download
scrolls_metric_path = hf_hub_download(
repo_id="tau/scrolls", repo_type="dataset", filename="metrics/scrolls.py"
)
updated_scrolls_metric_path = (
os.path.dirname(scrolls_metric_path)
+ os.path.basename(scrolls_metric_path).replace(".", "_")
+ ".py"
)
shutil.copy(scrolls_metric_path, updated_scrolls_metric_path)
return updated_scrolls_metric_path
def _process_doc_prepended_question(doc):
# "When a query is given in addition to the raw text (as
# in QMSum, Qasper, NarrativeQA, QuALITY, and ContractNLI),
# we prepend it to the text, using two newlines as a natural separator"
input = doc["input"]
split = input.find("\n\n")
return {
"id": doc["id"],
"pid": doc["pid"],
"input": input,
"outputs": doc["outputs"],
"question": input[0:split],
"text": input[split + 2 :],
}
def _drop_duplicates_in_input(untokenized_dataset):
# from scrolls/evaluator/dataset_evaluator.py
indices_to_keep = []
id_to_idx = {}
outputs = []
for i, (id_, output) in enumerate(
zip(untokenized_dataset["id"], untokenized_dataset["output"])
):
if id_ in id_to_idx:
outputs[id_to_idx[id_]].append(output)
continue
indices_to_keep.append(i)
id_to_idx[id_] = len(outputs)
outputs.append([output])
untokenized_dataset = untokenized_dataset.select(indices_to_keep).flatten_indices()
untokenized_dataset = untokenized_dataset.remove_columns("output")
untokenized_dataset = untokenized_dataset.add_column("outputs", outputs)
return untokenized_dataset
def _num_cpu_cores():
# https://stackoverflow.com/questions/1006289/how-to-find-out-the-number-of-cpus-using-python/55423170#55423170
try:
import psutil
return psutil.cpu_count(logical=False)
except ImportError:
import os
return len(os.sched_getaffinity(0))
class _SCROLLSTask(Task):
VERSION = 0
DATASET_PATH = "tau/scrolls"
DATASET_NAME = None
PRUNE_TOKENIZERS = None
PRUNE_MAX_TOKENS = None
PRUNE_NUM_PROC = None
def __post_init__(self):
self.metric = load_metric(_download_metric(), config_name=self.DATASET_NAME)
def has_training_docs(self):
return True
def has_validation_docs(self):
return True
def has_test_docs(self):
return False
def training_docs(self):
for doc in self.dataset["train"]:
yield from self._process_doc(doc)
def validation_docs(self):
for doc in self.dataset["validation"]:
yield from self._process_doc(doc)
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["input"]
def download(self, *args, **kwargs):
super().download(*args, **kwargs)
del self.dataset["test"]
for split in self.dataset:
self.dataset[split] = _drop_duplicates_in_input(self.dataset[split])
if self.PRUNE_TOKENIZERS is not None and self.PRUNE_TOKENIZERS is not None:
self.prune()
def _get_prune_text(self, sample):
return self.doc_to_text(self._process_doc(sample)[0])
def prune(self):
"""Create a pruned version of a SCROLLS task dataset containing only inputs
that are less than `max_tokens` when tokenized by each tokenizer
"""
tokenizers = [
AutoTokenizer.from_pretrained(tokenizer)
for tokenizer in self.PRUNE_TOKENIZERS
]
cache = {}
def _filter(sample):
text = self._get_prune_text(sample)
cached = cache.get(text, None)
if cached is None:
for tokenizer in tokenizers:
if len(tokenizer(text).input_ids) > self.PRUNE_MAX_TOKENS:
cache[text] = False
return False
cache[text] = True
return True
else:
return cached
self.dataset = self.dataset.filter(_filter, num_proc=self.PRUNE_NUM_PROC)
def doc_to_target(self, doc):
return " " + ", ".join(doc["outputs"])
def doc_to_text(self, doc):
return f"{doc['text']}\n\nQuestion: {doc['question']}\nAnswer:"
def higher_is_better(self):
return {x: True for x in self._scrolls_metrics().keys()}
@abstractmethod
def _scrolls_metrics(self):
pass
def _make_compute_metrics(self, value):
def compute_metrics(samples):
predictions, references = zip(*samples) # unzip, if you will
computed = self.metric.compute(
predictions=predictions, references=references
)
return computed[value]
return compute_metrics
def aggregation(self):
return {
key: self._make_compute_metrics(value)
for key, value in self._scrolls_metrics().items()
}
class _SCROLLSMultipleChoiceTask(_SCROLLSTask):
def __post_init__(self):
self.metric = None
def _scrolls_metrics(self):
return None
def aggregation(self):
return {"em": mean, "acc": mean, "acc_norm": mean}
def higher_is_better(self):
return {"em": True, "acc": True, "acc_norm": True}
def process_results(self, doc, results):
gold = doc["gold"]
acc = 1.0 if np.argmax(results) == gold else 0.0
completion_len = np.array([float(len(i)) for i in doc["choices"]])
acc_norm = 1.0 if np.argmax(results / completion_len) == gold else 0.0
return {
"acc": acc,
"acc_norm": acc_norm,
"em": acc_norm * 100.0,
}
def construct_requests(self, doc, ctx, **kwargs):
request_list = [
Instance(
request_type="loglikelihood",
doc=doc,
arguments=(ctx, " {}".format(choice)),
idx=i,
**kwargs,
)
for i, choice in doc["choices"]
]
return request_list
class _SCROLLSSummaryTask(_SCROLLSTask):
def _process_doc(self, doc):
return [doc]
def _scrolls_metrics(self):
return {
"rouge1": "rouge/rouge1",
"rouge2": "rouge/rouge2",
"rougeL": "rouge/rougeL",
}
def process_results(self, doc, results):
return {
"rouge1": (results[0], doc["outputs"]),
"rouge2": (results[0], doc["outputs"]),
"rougeL": (results[0], doc["outputs"]),
}
def construct_requests(self, doc, ctx, **kwargs):
return Instance(
request_type="generate_until",
doc=doc,
arguments=(ctx, {"until": ["\n"]}),
idx=0,
**kwargs,
)
def doc_to_text(self, doc):
return f"{doc['input']}\n\nQuestion: What is a summary of the preceding text?\nAnswer:"
@register_task("scrolls_qasper")
class Qasper(_SCROLLSTask):
"""A Dataset of Information-Seeking Questions and Answers Anchored in Research Papers
https://arxiv.org/abs/2105.03011
"""
DATASET_NAME = "qasper"
def _process_doc(self, doc):
doc = _process_doc_prepended_question(doc)
doc["is_yes_no"] = reduce(
lambda prev, cur: prev
and squad_metrics.normalize_answer(cur) in ["yes", "no"],
doc["outputs"],
True,
)
return [doc]
def _scrolls_metrics(self):
return {"f1": "f1"}
def process_results(self, doc, results):
if doc["is_yes_no"]:
prediction = " yes" if results[0] > results[1] else " no"
elif len(results[0].strip()) == 0:
prediction = "Unanswerable"
else:
prediction = results[0]
return {"f1": (prediction, doc["outputs"])}
def construct_requests(self, doc, ctx, **kwargs):
if doc["is_yes_no"]:
return [
Instance(
request_type="loglikelihood",
doc=doc,
arguments=(ctx, " yes"),
idx=0,
**kwargs,
),
Instance(
request_type="loglikelihood",
doc=doc,
arguments=(ctx, " no"),
idx=1,
**kwargs,
),
]
else:
return Instance(
request_type="generate_until",
doc=doc,
arguments=(ctx, {"until": ["\n"]}),
idx=0,
**kwargs,
)
@register_task("scrolls_quality")
class QuALITY(_SCROLLSMultipleChoiceTask):
"""QuALITY: Question Answering with Long Input Texts, Yes!
https://arxiv.org/abs/2112.08608
"""
DATASET_NAME = "quality"
_multiple_choice_pattern = re.compile(r" *\([A-D]\) *")
@staticmethod
def _normalize_answer(text):
return " ".join(text.split()).strip()
def _process_doc(self, doc):
doc = _process_doc_prepended_question(doc)
split = doc["text"].find("\n\n", doc["text"].find("(D)"))
choices_text = doc["text"][:split]
doc["text"] = doc["text"][split:].strip()
doc["choices"] = [
QuALITY._normalize_answer(choice)
for choice in re.split(QuALITY._multiple_choice_pattern, choices_text)[1:]
]
doc["gold"] = doc["choices"].index(QuALITY._normalize_answer(doc["outputs"][0]))
return [doc]
@register_task("scrolls_narrativeqa")
class NarrativeQA(_SCROLLSTask):
"""The NarrativeQA Reading Comprehension Challenge
https://arxiv.org/abs/1712.07040
"""
DATASET_NAME = "narrative_qa"
def _process_doc(self, doc):
return [_process_doc_prepended_question(doc)]
def _scrolls_metrics(self):
return {"f1": "f1"}
def _get_prune_text(self, doc):
# pruning narrativeqa takes forever -- let's cheat a bit
# and just cache on the text, not the question, since
# the dataset is different questions about the same large
# documents
return self._process_doc(doc)[0]["text"]
def process_results(self, doc, results):
return {"f1": (results[0], doc["outputs"])}
def construct_requests(self, doc, ctx, **kwargs):
return Instance(
request_type="generate_until",
doc=doc,
arguments=(ctx, {"until": ["\n"]}),
idx=0,
**kwargs,
)
@register_task("scrolls_contractnli")
class ContractNLI(_SCROLLSMultipleChoiceTask):
"""ContractNLI: A Dataset for Document-level Natural Language Inference for Contracts
https://arxiv.org/abs/1712.07040
"""
DATASET_NAME = "contract_nli"
CHOICES = ["Not mentioned", "Entailment", "Contradiction"]
def _process_doc(self, doc):
doc = _process_doc_prepended_question(doc)
doc["choices"] = ContractNLI.CHOICES
doc["gold"] = ContractNLI.CHOICES.index(doc["outputs"][0])
return [doc]
def doc_to_text(self, doc):
return f"{doc['text']}\n\nHypothesis: {doc['question']}\nConclusion:"
@register_task("scrolls_govreport")
class GovReport(_SCROLLSSummaryTask):
"""Efficient Attentions for Long Document Summarization
https://arxiv.org/abs/2104.02112
Note: The average length of the reference summaries is ~3,000
characters, or ~600 tokens as tokenized by GPT-NeoX. For causal models,
it is recommended to set `max_gen_toks` sufficently large (e.g. 1024)
to allow a full summary to be generated.
"""
DATASET_NAME = "gov_report"
@register_task("scrolls_summscreenfd")
class SummScreenFD(_SCROLLSSummaryTask):
"""SummScreen: A Dataset for Abstractive Screenplay Summarization
https://arxiv.org/abs/2104.07091
"""
DATASET_NAME = "summ_screen_fd"
@register_task("scrolls_qmsum")
class QMSum(_SCROLLSSummaryTask):
"""QMSum: A New Benchmark for Query-based Multi-domain
Meeting Summarization
https://arxiv.org/abs/2104.05938
"""
DATASET_NAME = "qmsum"
def _process_doc(self, doc):
return [_process_doc_prepended_question(doc)]
def doc_to_text(self, doc):
return f"{doc['text']}\n\nQuestion: {doc['question']}\nAnswer:"
......@@ -34,12 +34,11 @@ Homepage: https://rajpurkar.github.io/SQuAD-explorer/
#### Groups
* `squadv2_complete`: Runs both `squadv2` and `squadv2_noans_loglikelihood`
* Not part of a group yet
#### Tasks
* `squadv2`: `Default squadv2 task`
* `squadv2_noans_loglikelihood`: `Additional task to acquire the probability of model predicting there is no answer`
### Checklist
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment