Unverified Commit 1044db95 authored by Baber Abbasi, committed by GitHub

update scrolls (#2602)

* switch metric loading to the `evaluate` library (`evaluate.load` replaces `datasets.load_metric`); update `construct_requests`

* update `construct_requests` to handle the `apply_chat_template` kwarg (a minimal sketch of the continuation handling follows below)
parent aa72104b
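
Below is a minimal sketch, not the lm-eval API itself, of the continuation handling this change introduces: when a chat template formats the prompt, the leading space is dropped from the loglikelihood continuation, since the template rather than the raw prompt then controls the boundary between context and target. The `build_continuation` helper and the example strings are illustrative only.

```python
# Illustrative helper, not part of lm-eval: mirrors the pattern used in
# `construct_requests` after this change.
def build_continuation(ctx: str, choice: str, apply_chat_template: bool = False):
    """Return the (context, continuation) pair for a loglikelihood request."""
    # Plain completion prompt: keep the leading space so the target attaches
    # naturally to the context. Chat-templated prompt: drop the leading space.
    continuation = "{}".format(choice) if apply_chat_template else " {}".format(choice)
    return (ctx, continuation)


if __name__ == "__main__":
    ctx = "Question: Is the sky blue?\nAnswer:"
    print(build_continuation(ctx, "yes"))                            # continuation is " yes"
    print(build_continuation(ctx, "yes", apply_chat_template=True))  # continuation is "yes"
```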
@@ -4,7 +4,8 @@ from functools import reduce
 import numpy as np
 import transformers.data.metrics.squad_metrics as squad_metrics
-from datasets import Dataset, load_metric
+from datasets import Dataset
+from evaluate import load
 from transformers import AutoTokenizer
 from lm_eval.api.instance import Instance
@@ -48,7 +49,10 @@ def _download_metric():
     from huggingface_hub import hf_hub_download
     scrolls_metric_path = hf_hub_download(
-        repo_id="tau/scrolls", repo_type="dataset", filename="metrics/scrolls.py"
+        repo_id="tau/scrolls",
+        repo_type="dataset",
+        filename="metrics/scrolls.py",
+        revision="refs/pr/5",
     )
     updated_scrolls_metric_path = (
         os.path.dirname(scrolls_metric_path)
@@ -119,7 +123,7 @@ class _SCROLLSTask(ConfigurableTask):
     def __init__(self, config=None):
         super().__init__(config={"metadata": {"version": self.VERSION}})
         if self.DATASET_NAME is not None:
-            self.metric = load_metric(_download_metric(), config_name=self.DATASET_NAME)
+            self.metric = load(_download_metric(), config_name=self.DATASET_NAME)
     def has_training_docs(self):
         return True
@@ -253,11 +257,14 @@ class _SCROLLSMultipleChoiceTask(_SCROLLSTask):
         }
     def construct_requests(self, doc, ctx, **kwargs):
+        apply_chat_template = kwargs.pop("apply_chat_template", False)
         request_list = [
             Instance(
                 request_type="loglikelihood",
                 doc=doc,
-                arguments=(ctx, " {}".format(choice)),
+                arguments=(ctx, " {}".format(choice))
+                if not apply_chat_template
+                else (ctx, "{}".format(choice)),
                 idx=i,
                 **kwargs,
             )
@@ -285,6 +292,7 @@ class _SCROLLSSummaryTask(_SCROLLSTask):
         }
     def construct_requests(self, doc, ctx, **kwargs):
+        kwargs.pop("apply_chat_template", False)
         return Instance(
             request_type="generate_until",
             doc=doc,
@@ -327,19 +335,22 @@ class Qasper(_SCROLLSTask):
         return {"f1": (prediction, doc["outputs"])}
     def construct_requests(self, doc, ctx, **kwargs):
+        apply_chat_template = kwargs.pop("apply_chat_template", False)
         if doc["is_yes_no"]:
             return [
                 Instance(
                     request_type="loglikelihood",
                     doc=doc,
-                    arguments=(ctx, " yes"),
+                    arguments=(ctx, " yes")
+                    if not apply_chat_template
+                    else (ctx, "yes"),
                     idx=0,
                     **kwargs,
                 ),
                 Instance(
                     request_type="loglikelihood",
                     doc=doc,
-                    arguments=(ctx, " no"),
+                    arguments=(ctx, " no") if not apply_chat_template else (ctx, "no"),
                     idx=1,
                     **kwargs,
                 ),
@@ -406,6 +417,7 @@ class NarrativeQA(_SCROLLSTask):
         return {"f1": (results[0], doc["outputs"])}
     def construct_requests(self, doc, ctx, **kwargs):
+        kwargs.pop("apply_chat_template", False)
         return Instance(
             request_type="generate_until",
             doc=doc,
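
As a companion to the `evaluate.load` change above, here is a minimal sketch of loading the SCROLLS metric script at the revision pinned in this PR; it is not the task code itself. The copy target `scrolls_metric.py` and the config name `"gov_report"` are illustrative assumptions, not values taken from this diff, and the surrounding task code builds its own renamed copy of the script (only partially visible in the hunk above).

```python
import os
import shutil

import evaluate
from huggingface_hub import hf_hub_download

# Download the SCROLLS metric script at the revision pinned in this PR.
path = hf_hub_download(
    repo_id="tau/scrolls",
    repo_type="dataset",
    filename="metrics/scrolls.py",
    revision="refs/pr/5",
)

# The task code loads a renamed copy of the script; this sketch copies it to
# an illustrative name before loading (target filename is an assumption).
renamed = os.path.join(os.path.dirname(path), "scrolls_metric.py")
shutil.copy(path, renamed)

# `config_name` selects the per-subset metric configuration; "gov_report"
# is an example value, not taken from this diff.
metric = evaluate.load(renamed, config_name="gov_report")
```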