Unverified Commit 1044db95 authored by Baber Abbasi's avatar Baber Abbasi Committed by GitHub
Browse files

update scrolls (#2602)

* update evaluate; update construct requests

* update construct requests to handle `apply_chat_template` kwarg
parent aa72104b
...@@ -4,7 +4,8 @@ from functools import reduce ...@@ -4,7 +4,8 @@ from functools import reduce
import numpy as np import numpy as np
import transformers.data.metrics.squad_metrics as squad_metrics import transformers.data.metrics.squad_metrics as squad_metrics
from datasets import Dataset, load_metric from datasets import Dataset
from evaluate import load
from transformers import AutoTokenizer from transformers import AutoTokenizer
from lm_eval.api.instance import Instance from lm_eval.api.instance import Instance
...@@ -48,7 +49,10 @@ def _download_metric(): ...@@ -48,7 +49,10 @@ def _download_metric():
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
scrolls_metric_path = hf_hub_download( scrolls_metric_path = hf_hub_download(
repo_id="tau/scrolls", repo_type="dataset", filename="metrics/scrolls.py" repo_id="tau/scrolls",
repo_type="dataset",
filename="metrics/scrolls.py",
revision="refs/pr/5",
) )
updated_scrolls_metric_path = ( updated_scrolls_metric_path = (
os.path.dirname(scrolls_metric_path) os.path.dirname(scrolls_metric_path)
...@@ -119,7 +123,7 @@ class _SCROLLSTask(ConfigurableTask): ...@@ -119,7 +123,7 @@ class _SCROLLSTask(ConfigurableTask):
def __init__(self, config=None): def __init__(self, config=None):
super().__init__(config={"metadata": {"version": self.VERSION}}) super().__init__(config={"metadata": {"version": self.VERSION}})
if self.DATASET_NAME is not None: if self.DATASET_NAME is not None:
self.metric = load_metric(_download_metric(), config_name=self.DATASET_NAME) self.metric = load(_download_metric(), config_name=self.DATASET_NAME)
def has_training_docs(self): def has_training_docs(self):
return True return True
...@@ -253,11 +257,14 @@ class _SCROLLSMultipleChoiceTask(_SCROLLSTask): ...@@ -253,11 +257,14 @@ class _SCROLLSMultipleChoiceTask(_SCROLLSTask):
} }
def construct_requests(self, doc, ctx, **kwargs): def construct_requests(self, doc, ctx, **kwargs):
apply_chat_template = kwargs.pop("apply_chat_template", False)
request_list = [ request_list = [
Instance( Instance(
request_type="loglikelihood", request_type="loglikelihood",
doc=doc, doc=doc,
arguments=(ctx, " {}".format(choice)), arguments=(ctx, " {}".format(choice))
if not apply_chat_template
else (ctx, "{}".format(choice)),
idx=i, idx=i,
**kwargs, **kwargs,
) )
...@@ -285,6 +292,7 @@ class _SCROLLSSummaryTask(_SCROLLSTask): ...@@ -285,6 +292,7 @@ class _SCROLLSSummaryTask(_SCROLLSTask):
} }
def construct_requests(self, doc, ctx, **kwargs): def construct_requests(self, doc, ctx, **kwargs):
kwargs.pop("apply_chat_template", False)
return Instance( return Instance(
request_type="generate_until", request_type="generate_until",
doc=doc, doc=doc,
...@@ -327,19 +335,22 @@ class Qasper(_SCROLLSTask): ...@@ -327,19 +335,22 @@ class Qasper(_SCROLLSTask):
return {"f1": (prediction, doc["outputs"])} return {"f1": (prediction, doc["outputs"])}
def construct_requests(self, doc, ctx, **kwargs): def construct_requests(self, doc, ctx, **kwargs):
apply_chat_template = kwargs.pop("apply_chat_template", False)
if doc["is_yes_no"]: if doc["is_yes_no"]:
return [ return [
Instance( Instance(
request_type="loglikelihood", request_type="loglikelihood",
doc=doc, doc=doc,
arguments=(ctx, " yes"), arguments=(ctx, " yes")
if not apply_chat_template
else (ctx, "yes"),
idx=0, idx=0,
**kwargs, **kwargs,
), ),
Instance( Instance(
request_type="loglikelihood", request_type="loglikelihood",
doc=doc, doc=doc,
arguments=(ctx, " no"), arguments=(ctx, " no") if not apply_chat_template else (ctx, "no"),
idx=1, idx=1,
**kwargs, **kwargs,
), ),
...@@ -406,6 +417,7 @@ class NarrativeQA(_SCROLLSTask): ...@@ -406,6 +417,7 @@ class NarrativeQA(_SCROLLSTask):
return {"f1": (results[0], doc["outputs"])} return {"f1": (results[0], doc["outputs"])}
def construct_requests(self, doc, ctx, **kwargs): def construct_requests(self, doc, ctx, **kwargs):
kwargs.pop("apply_chat_template", False)
return Instance( return Instance(
request_type="generate_until", request_type="generate_until",
doc=doc, doc=doc,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment