Commit 8bf55a20 authored by lintangsutawika

add squad from master

parent b7a4ea06
@@ -91,7 +91,7 @@ class TaskConfig(dict):
     metadata: str = None  # by default, not used in the code. allows for users to pass arbitrary info to tasks

     def __post_init__(self) -> None:
-        if "." in self.dataset_path:
+        if self.dataset_path and ("." in self.dataset_path):
             import inspect
             from importlib import import_module
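For context, a minimal standalone sketch (hypothetical value, not part of the diff) of why the added guard matters: the old check raises when dataset_path is left unset.

dataset_path = None  # hypothetical: a TaskConfig with no dataset_path set
# old check:  "." in dataset_path                    -> TypeError: argument of type 'NoneType' is not iterable
# new check:  dataset_path and ("." in dataset_path) -> falsy, so the local-module branch is skipped
uses_local_module = bool(dataset_path and ("." in dataset_path))
print(uses_local_module)  # False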
@@ -204,19 +204,19 @@ class Task(abc.ABC):
         self._fewshot_docs = None
         self._instances = None

-        self._config = TaskConfig(**config) if config else TaskConfig()
-        if not hasattr(self, "_filters"):
-            self._filters = []
-            for name, components in self._config.get(
-                "filters", [["none", [["take_first", None]]]]
-            ):
-                filter_pipeline = build_filter_ensemble(name, components)
-                self._filters.append(filter_pipeline)
-        self.sampler = samplers.Sampler(
-            list(self.fewshot_docs()), self, rnd=random.Random(1234)
-        )
+        self._config = (
+            TaskConfig(
+                {
+                    **config,
+                    **{"dataset_path": DATASET_PATH, "dataset_name": DATASET_NAME},
+                }
+            )
+            if config
+            else TaskConfig()
+        )
+        self._filters = [build_filter_ensemble("none", [["take_first", None]])]

     def download(self, data_dir=None, cache_dir=None, download_mode=None) -> None:
         """Downloads and returns the task dataset.
@@ -358,7 +358,7 @@ class Task(abc.ABC):
         ), f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!"

         eval_logger.info(
-            f"Building contexts for task '{self.config.task}' on rank {rank}..."
+            f"Building contexts for task on rank {rank}..."
         )

         instances = []
@@ -449,7 +449,9 @@ class Task(abc.ABC):
         return len(re.split(r"\s+", doc))

     @utils.positional_deprecated
-    def fewshot_context(self, doc, num_fewshot):
+    def fewshot_context(
+        self, doc, num_fewshot, provide_description=None, rnd=random.Random(1234), description=None
+    ):
         """Returns a fewshot context string that is made up of a prepended description
         (if provided), the `num_fewshot` number of examples, and an appended prompt example.
@@ -457,34 +459,68 @@
             The document as returned from training_docs, validation_docs, or test_docs.
         :param num_fewshot: int
             The number of fewshot examples to provide in the returned context string.
+        :param provide_description: bool
+            Not implemented, and this option is deprecated and will be removed in a future version in favor of a different description providing method
+        :param rnd: random.Random
+            The pseudo-random number generator used to randomly sample examples.
+            WARNING: This is currently a required arg although it's optionalized with a default `None`.
+        :param description: str
+            The task's description that will be prepended to the fewshot examples.
         :returns: str
             The fewshot context.
         """
+        assert (
+            rnd is not None
+        ), "A `random.Random` generator argument must be provided to `rnd`"
+        assert not provide_description, (
+            "The `provide_description` arg will be removed in future versions. To prepend "
+            "a custom description to the context, supply the corresponding string via the "
+            "`description` arg."
+        )
+        if provide_description is not None:
+            # nudge people to not specify it at all
+            print(
+                "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
+            )
+
+        description = description + "\n\n" if description else ""
+
         if num_fewshot == 0:
-            # always prepend the (possibly empty) task description
-            labeled_examples = self.config.description
+            labeled_examples = ""
         else:
-            labeled_examples = self.config.description + self.sampler.get_context(
-                doc, num_fewshot
+            # for sets with no training docs, draw from other set *but ensure no overlap with current doc*
+            if self.has_training_docs():
+                fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
+            else:
+                if self._fewshot_docs is None:
+                    self._fewshot_docs = list(
+                        self.validation_docs()
+                        if self.has_validation_docs()
+                        else self.test_docs()
+                    )
+
+                fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)
+
+                # get rid of the doc that's the one we're evaluating, if it's in the fewshot
+                fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]
+
+            labeled_examples = (
+                "\n\n".join(
+                    [
+                        self.doc_to_text(doc) + self.doc_to_target(doc)
+                        for doc in fewshotex
+                    ]
+                )
+                + "\n\n"
             )

         example = self.doc_to_text(doc)
-        if type(example) == str:
-            return labeled_examples + example
-        elif type(example) == list:
-            return [labeled_examples + ex for ex in example]
-        elif type(example) == int:
-            if self.config.doc_to_choice is not None:
-                choices = self.doc_to_choice(doc)
-                return labeled_examples + choices[example]
-            else:
-                return labeled_examples + str(example)
+        return description + labeled_examples + example

     def apply_filters(self):
         if hasattr(self, "_filters"):
             for f in self._filters:
-                f.apply(self._instances)
+                f.apply(self._instances, None)
         else:
             eval_logger.warning("No filter defined, passing through instances")
             return self._instances
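To make the restored control flow concrete, here is a minimal standalone sketch (hypothetical docs and doc_to_text/doc_to_target stand-ins, not harness code) of how `fewshot_context` assembles its return value:

description = "Answer each question.\n\n"  # already suffixed with "\n\n" as in the code above
fewshotex = [("Q: 2+2?\nA:", " 4"), ("Q: Capital of France?\nA:", " Paris")]  # (text, target) pairs
labeled_examples = "\n\n".join(text + target for text, target in fewshotex) + "\n\n"
example = "Q: 3+3?\nA:"  # doc_to_text(doc) for the document being evaluated
context = description + labeled_examples + example
print(repr(context))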
@@ -764,6 +800,41 @@ class ConfigurableTask(Task):
             )
             return super().fewshot_docs()

+    @utils.positional_deprecated
+    def fewshot_context(self, doc, num_fewshot):
+        """Returns a fewshot context string that is made up of a prepended description
+        (if provided), the `num_fewshot` number of examples, and an appended prompt example.
+
+        :param doc: str
+            The document as returned from training_docs, validation_docs, or test_docs.
+        :param num_fewshot: int
+            The number of fewshot examples to provide in the returned context string.
+        :returns: str
+            The fewshot context.
+        """
+        if num_fewshot == 0:
+            # always prepend the (possibly empty) task description
+            labeled_examples = self.config.description
+        else:
+            labeled_examples = self.config.description + self.sampler.get_context(
+                doc, num_fewshot
+            )
+
+        example = self.doc_to_text(doc)
+        if type(example) == str:
+            return labeled_examples + example
+        elif type(example) == list:
+            return [labeled_examples + ex for ex in example]
+        elif type(example) == int:
+            if self.config.doc_to_choice is not None:
+                choices = self.doc_to_choice(doc)
+                return labeled_examples + choices[example]
+            else:
+                return labeled_examples + str(example)
+
     def apply_filters(self):
         if hasattr(self, "_filters"):
             for f in self._filters:
...
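The `ConfigurableTask.fewshot_context` added back above branches on the type returned by `doc_to_text`. A quick sketch with hypothetical values (not harness code) of the `int` branch, where the returned value indexes into `doc_to_choice(doc)`:

labeled_examples = "Question: 2+2?\nAnswer: 4\n\n"  # hypothetical few-shot block
choices = ["no", "yes"]                              # what doc_to_choice(doc) might return
example = 1                                          # doc_to_text(doc) returned an index
context = labeled_examples + choices[example]
print(repr(context))  # few-shot block followed by "yes"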
@@ -268,12 +268,9 @@ def evaluate(
                 eval_logger.info(f"Request: {str(inst)}")

         # aggregate Instances by LM method requested to get output.
-        reqtype = (
-            "loglikelihood"
-            if task.OUTPUT_TYPE == "multiple_choice"
-            else task.OUTPUT_TYPE
-        )  # TODO: this is hacky, fix in task.py
-        requests[reqtype].extend(task.instances)
+        for instance in task.instances:
+            reqtype = instance.request_type
+            requests[reqtype].append(instance)

         if lm.world_size > 1:
             instances_rnk = torch.tensor(len(task._instances), device=lm.device)
...
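The evaluator change replaces the per-task OUTPUT_TYPE guess with per-instance grouping. A self-contained sketch of that pattern (assuming `requests` maps request types to lists, e.g. a defaultdict(list); `FakeInstance` is a stand-in for `lm_eval.api.instance.Instance`):

from collections import defaultdict

class FakeInstance:  # stand-in for lm_eval.api.instance.Instance
    def __init__(self, request_type):
        self.request_type = request_type

requests = defaultdict(list)
task_instances = [FakeInstance("generate_until"), FakeInstance("loglikelihood")]
for instance in task_instances:
    reqtype = instance.request_type
    requests[reqtype].append(instance)

print({k: len(v) for k, v in requests.items()})  # {'generate_until': 1, 'loglikelihood': 1}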
@@ -15,6 +15,8 @@ from lm_eval.api.registry import (
 import logging

+from .squad import SQuAD2
+
 eval_logger = logging.getLogger("lm-eval")
...
import datasets
from math import exp
from functools import partial
from packaging import version

from lm_eval.api.task import Task
from lm_eval.api.instance import Instance
from lm_eval.api.registry import register_task


def _squad_metric(predictions, references):
    squad_metric = datasets.load_metric("squad_v2")
    return squad_metric.compute(predictions=predictions, references=references)


def _squad_agg(key, items):
    predictions, references = zip(*items)
    return _squad_metric(predictions=predictions, references=references).get(key, 0)


@register_task("squadv2")
class SQuAD2(Task):
    VERSION = 1
    DATASET_PATH = "squad_v2"
    DATASET_NAME = None

    # HF changed squad on us so we have to make sure we aren't running the old one
    assert version.parse(datasets.__version__) >= version.parse(
        "1.11.0"
    ), "datasets v1.11.0 or later required for SQuAD"
    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return False

    def training_docs(self):
        return self.dataset["train"]

    def validation_docs(self):
        return self.dataset["validation"]

    def doc_to_text(self, doc):
        return (
            "Title: "
            + doc["title"]
            + "\n\n"
            + "Background: "
            + doc["context"]
            + "\n\n"
            + "Question: "
            + doc["question"]
            + "\n\n"
            + "Answer:"
        )

    def should_decontaminate(self):
        return True

    def doc_to_decontamination_query(self, doc):
        return doc["context"]

    def doc_to_target(self, doc):
        answer_list = doc["answers"]["text"]
        if len(answer_list) > 0:
            answer = answer_list[0]
        else:
            answer = "unanswerable"
        return " " + answer
    def construct_requests(self, doc, ctx, **kwargs):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        """
        return [
            Instance(
                request_type="generate_until",
                doc=doc,
                arguments=(ctx, {"until": ["\n"]}),
                idx=0,
                **kwargs
            ),
            Instance(
                request_type="loglikelihood",
                doc=doc,
                arguments=(ctx, " " + "unanswerable"),
                idx=0,
                **kwargs
            ),
        ]
    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests.
        """
        continuation, (logprob_unanswerable, _) = results

        no_answer_probability = exp(logprob_unanswerable)

        predictions = {
            "id": doc["id"],
            "prediction_text": continuation,
            "no_answer_probability": no_answer_probability,
        }

        references = {
            "id": doc["id"],
            "answers": doc["answers"],
        }

        return {
            "exact": (
                predictions,
                references,
            ),  # Exact match (the normalized answer exactly match the gold answer)
            "f1": (
                predictions,
                references,
            ),  # The F-score of predicted tokens versus the gold answer
            "HasAns_exact": (
                predictions,
                references,
            ),  # Exact match (the normalized answer exactly match the gold answer)
            "HasAns_f1": (
                predictions,
                references,
            ),  # The F-score of predicted tokens versus the gold answer
            "NoAns_exact": (
                predictions,
                references,
            ),  # Exact match (the normalized answer exactly match the gold answer)
            "NoAns_f1": (
                predictions,
                references,
            ),  # The F-score of predicted tokens versus the gold answer
            "best_exact": (
                predictions,
                references,
            ),  # Best exact match (with varying threshold)
            "best_f1": (predictions, references),  # Best F1 (with varying threshold)
        }
    def aggregation(self):
        """
        :returns: {str: [float] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metrics
        """
        return {
            "exact": partial(
                _squad_agg, "exact"
            ),  # Exact match (the normalized answer exactly match the gold answer)
            "f1": partial(
                _squad_agg, "f1"
            ),  # The F-score of predicted tokens versus the gold answer
            "HasAns_exact": partial(
                _squad_agg, "HasAns_exact"
            ),  # Exact match (the normalized answer exactly match the gold answer)
            "HasAns_f1": partial(
                _squad_agg, "HasAns_f1"
            ),  # The F-score of predicted tokens versus the gold answer
            "NoAns_exact": partial(
                _squad_agg, "NoAns_exact"
            ),  # Exact match (the normalized answer exactly match the gold answer)
            "NoAns_f1": partial(
                _squad_agg, "NoAns_f1"
            ),  # The F-score of predicted tokens versus the gold answer
            "best_exact": partial(
                _squad_agg, "best_exact"
            ),  # Best exact match (with varying threshold)
            "best_f1": partial(
                _squad_agg, "best_f1"
            ),  # Best F1 (with varying threshold)
        }
    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
        return {
            "exact": True,  # Exact match (the normalized answer exactly match the gold answer)
            "f1": True,  # The F-score of predicted tokens versus the gold answer
            "HasAns_exact": True,  # Exact match (the normalized answer exactly match the gold answer)
            "HasAns_f1": True,  # The F-score of predicted tokens versus the gold answer
            "NoAns_exact": True,  # Exact match (the normalized answer exactly match the gold answer)
            "NoAns_f1": True,  # The F-score of predicted tokens versus the gold answer
            "best_exact": True,  # Best exact match (with varying threshold)
            "best_f1": True,  # Best F1 (with varying threshold)
        }
\ No newline at end of file
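End-to-end, the new task issues two requests per document (a free-form generation and the loglikelihood of " unanswerable") and converts the latter into a no-answer probability for the squad_v2 metric. A worked sketch with a hypothetical document and hypothetical model outputs (not part of the commit; the dict shapes mirror process_results above):

from math import exp

doc = {
    "id": "q1",
    "title": "Normandy",
    "context": "Normandy is a region in France.",
    "question": "In what country is Normandy located?",
    "answers": {"text": ["France"], "answer_start": [24]},
}

# Hypothetical LM outputs, in the order the two requests were constructed:
results = ["France", (-4.6, False)]  # [generated continuation, (loglikelihood of " unanswerable", is_greedy)]
continuation, (logprob_unanswerable, _) = results

predictions = {
    "id": doc["id"],
    "prediction_text": continuation,
    "no_answer_probability": exp(logprob_unanswerable),  # roughly 0.01
}
references = {"id": doc["id"], "answers": doc["answers"]}

# Each metric key maps to the same (predictions, references) pair; aggregation
# unzips these pairs and hands them to the squad_v2 metric,
# e.g. _squad_agg("exact", items) or _squad_agg("f1", items).
items = [(predictions, references)]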
 include: _template_yaml
-task: squadv2
+task: squadv2_generate_until
 output_type: generate_until
 generation_kwargs:
   until:
...
 group: squadv2_complete
 task:
-  - squadv2
+  - squadv2_generate_until
   - squadv2_noans_loglikelihood