"git@developer.sourcefind.cn:laibao/llava_vllm.git" did not exist on "769c852d48c89d1d4ed6add38a481821b93d108f"
Unverified commit 11f614b0, authored by Stella Biderman, committed by GitHub

Merge branch 'master' into task_doc

parents 0a6a9b7e e00d682f
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# NOTE: This is a modified version of https://github.com/huggingface/datasets/blob/master/datasets/wikitext/wikitext.py
# that returns Wiki pages instead of Wiki text line-by-line.
"""WikiText Dataset."""
import os
import datasets
_CITATION = """\
@misc{merity2016pointer,
title={Pointer Sentinel Mixture Models},
author={Stephen Merity and Caiming Xiong and James Bradbury and Richard Socher},
year={2016},
eprint={1609.07843},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
"""
_DESCRIPTION = """\
The WikiText language modeling dataset is a collection of over 100 million tokens extracted from the set of verified
Good and Featured articles on Wikipedia. The dataset is available under the Creative Commons Attribution-ShareAlike
License.
"""
_HOMEPAGE = "https://blog.einstein.ai/the-wikitext-long-term-dependency-language-modeling-dataset/"
_LICENSE = "Creative Commons Attribution-ShareAlike 4.0 International (CC BY-SA 4.0)"
_DATA_URL = "https://s3.amazonaws.com/research.metamind.io/wikitext"
class WikitextConfig(datasets.BuilderConfig):
"""BuilderConfig for GLUE."""
def __init__(self, data_url, **kwargs):
"""BuilderConfig for Wikitext
Args:
data_url: `string`, url to the dataset (word or raw level)
**kwargs: keyword arguments forwarded to super.
"""
super(WikitextConfig, self).__init__(
version=datasets.Version(
"1.0.0",
),
**kwargs,
)
self.data_url = data_url
class Wikitext(datasets.GeneratorBasedBuilder):
"""TODO(wikitext_103): Short description of my dataset."""
# TODO(wikitext_103): Set up version.
VERSION = datasets.Version("0.1.0")
BUILDER_CONFIGS = [
WikitextConfig(
name="wikitext-103-v1",
data_url=_DATA_URL + "/" + "wikitext-103-v1.zip",
description="Word level dataset. No processing is needed other than replacing newlines with <eos> tokens.",
),
WikitextConfig(
name="wikitext-2-v1",
data_url=_DATA_URL + "/" + "wikitext-2-v1.zip",
description="Word level dataset. No processing is needed other than replacing newlines with <eos> tokens.",
),
WikitextConfig(
name="wikitext-103-raw-v1",
data_url=_DATA_URL + "/" + "wikitext-103-raw-v1.zip",
description="Raw level dataset: the raw tokens before the addition of <unk> tokens. "
"They should only be used for character level work or for creating newly derived datasets.",
),
WikitextConfig(
name="wikitext-2-raw-v1",
data_url=_DATA_URL + "/" + "wikitext-2-raw-v1.zip",
description="Raw level dataset: the raw tokens before the addition of <unk> tokens. "
"They should only be used for character level work or for creating newly derived datasets.",
),
]
def _info(self):
# TODO(wikitext): Specify the datasets.DatasetInfo object
return datasets.DatasetInfo(
# This is the description that will appear on the datasets page.
description=_DESCRIPTION,
# datasets.features.FeatureConnectors
features=datasets.Features(
{
"page": datasets.Value("string")
# These are the features of your dataset like images, labels ...
}
),
# If there's a common (input, target) tuple from the features,
# specify them here. They'll be used if as_supervised=True in
# builder.as_dataset.
supervised_keys=None,
homepage=_HOMEPAGE,
license=_LICENSE,
citation=_CITATION,
)
def _split_generators(self, dl_manager):
"""Returns SplitGenerators."""
# TODO(wikitext): Downloads the data and defines the splits
# dl_manager is a datasets.download.DownloadManager that can be used to
# download and extract URLs
if self.config.name == "wikitext-103-v1":
data_file = dl_manager.download_and_extract(self.config.data_url)
data_dir = os.path.join(data_file, "wikitext-103")
return [
datasets.SplitGenerator(
name=datasets.Split.TEST,
gen_kwargs={"data_file": os.path.join(
data_dir, "wiki.test.tokens"), "split": "test"},
),
datasets.SplitGenerator(
name=datasets.Split.TRAIN,
gen_kwargs={"data_file": os.path.join(
data_dir, "wiki.train.tokens"), "split": "train"},
),
datasets.SplitGenerator(
name=datasets.Split.VALIDATION,
gen_kwargs={"data_file": os.path.join(
data_dir, "wiki.valid.tokens"), "split": "valid"},
),
]
else:
if self.config.name == "wikitext-103-raw-v1":
data_file = dl_manager.download_and_extract(
self.config.data_url)
data_dir = os.path.join(data_file, "wikitext-103-raw")
return [
datasets.SplitGenerator(
name=datasets.Split.TEST,
gen_kwargs={"data_file": os.path.join(
data_dir, "wiki.test.raw"), "split": "test"},
),
datasets.SplitGenerator(
name=datasets.Split.TRAIN,
gen_kwargs={"data_file": os.path.join(
data_dir, "wiki.train.raw"), "split": "train"},
),
datasets.SplitGenerator(
name=datasets.Split.VALIDATION,
gen_kwargs={"data_file": os.path.join(
data_dir, "wiki.valid.raw"), "split": "valid"},
),
]
else:
if self.config.name == "wikitext-2-raw-v1":
data_file = dl_manager.download_and_extract(
self.config.data_url)
data_dir = os.path.join(data_file, "wikitext-2-raw")
return [
datasets.SplitGenerator(
name=datasets.Split.TEST,
gen_kwargs={"data_file": os.path.join(
data_dir, "wiki.test.raw"), "split": "test"},
),
datasets.SplitGenerator(
name=datasets.Split.TRAIN,
gen_kwargs={"data_file": os.path.join(
data_dir, "wiki.train.raw"), "split": "train"},
),
datasets.SplitGenerator(
name=datasets.Split.VALIDATION,
gen_kwargs={"data_file": os.path.join(
data_dir, "wiki.valid.raw"), "split": "valid"},
),
]
else:
if self.config.name == "wikitext-2-v1":
data_file = dl_manager.download_and_extract(
self.config.data_url)
data_dir = os.path.join(data_file, "wikitext-2")
return [
datasets.SplitGenerator(
name=datasets.Split.TEST,
gen_kwargs={"data_file": os.path.join(
data_dir, "wiki.test.tokens"), "split": "test"},
),
datasets.SplitGenerator(
name=datasets.Split.TRAIN,
gen_kwargs={
"data_file": os.path.join(data_dir, "wiki.train.tokens"),
"split": "train",
},
),
datasets.SplitGenerator(
name=datasets.Split.VALIDATION,
gen_kwargs={
"data_file": os.path.join(data_dir, "wiki.valid.tokens"),
"split": "valid",
},
),
]
def _generate_examples(self, data_file, split):
"""Yields examples."""
with open(data_file, encoding="utf-8") as f:
key = 0
ret = []
data = f.read().split("\n")
for line in data:
rline = line.replace("= = =", "===").replace("= =", "==").strip()
if rline.startswith('= ') and rline.strip().endswith(' ='):
page = '\n'.join(ret)
if page.strip():
yield key, {"page": page}
key += 1
ret = []
ret.append(line)
page = '\n'.join(ret)
yield key, {"page": page}
@@ -4,6 +4,9 @@ import json
import jsonlines
import io
import datetime
+import mmap
+import tqdm
+from pathlib import Path
def json_serial(obj):
"""JSON serializer for objects not serializable by default json code"""
@@ -59,16 +62,19 @@ class Reader:
else:
yield text
+# Simple text reader and writer with same interface as above
class TextArchive:
-def __init__(self, file_path, mode="ab"):
+def __init__(self, file_path, mode="rb+"):
self.file_path = file_path
dir_name = os.path.dirname(file_path)
if dir_name:
os.makedirs(dir_name, exist_ok=True)
-self.fh = open(self.file_path, mode)
+if not os.path.exists(file_path):
+Path(file_path).touch()
+self.fh = open(self.file_path, mode)
-def add_data(self, data, meta={}):
+def add_data(self, data):
self.fh.write(data.encode('UTF-8') + b'\n')
def commit(self):
# Optimized mmap read with infrequent tqdm updates to maintain speed
# Tested up to 250MB/s.
def read_tqdm(self, update_frequency=10000):
current_file_position = 0
line_counter = 0
with open(self.file_path, 'r') as fh, \
tqdm.tqdm(total=os.path.getsize(self.file_path), dynamic_ncols=True,
unit="byte", unit_scale=1) as progress:
with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
for line in iter(mmap_obj.readline, b""):
line = line.decode("utf-8")
line_counter += 1
if line_counter == update_frequency:
new_file_pos = mmap_obj.tell()
bytes_read = new_file_pos - current_file_position
current_file_position = new_file_pos
progress.update(bytes_read)
line_counter = 0
yield line[:-1]
def read_and_tell(self):
current_file_position = 0
with open(self.file_path, 'r', encoding="utf8") as fh:
with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
for line in iter(mmap_obj.readline, b""):
line = line.decode("utf-8")
new_file_pos = mmap_obj.tell()
raw_bytes_read = new_file_pos - current_file_position
current_file_position = new_file_pos
yield line[:-1], raw_bytes_read
def read(self):
with open(self.file_path, 'r', encoding="utf8") as fh:
-self.fh = fh
+with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
+for line in iter(mmap_obj.readline, b""):
+line = line.decode("utf-8")
+yield line[:-1]
+def read_slow(self):
+with open(self.file_path, 'r', encoding="utf8") as fh:
while True:
-line = self.fh.readline()
+line = fh.readline()
if line == -1 or line == "":
break
-else :
+else:
yield line[:-1]
\ No newline at end of file
# Optimized for speed. Decompresses the archive in shell before
# using the mmap'd TextReader.
class ZStdTextReader:
def __init__(self, file):
self.file = file
def read_tqdm(self):
decompressed_file = self.file[:-4]
print("Decompressing file, please wait...")
os.system(f"zstd -d {self.file}") # linux decompress is faster
reader = TextReader(decompressed_file)
yield from reader.read_tqdm()
os.remove(decompressed_file)
\ No newline at end of file
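The readers above exist to stream the very large sorted n-gram files produced for decontamination. A rough usage sketch (the .zst path is a placeholder; the zstd CLI must be installed because ZStdTextReader decompresses via the shell before mmap-reading the result):

reader = ZStdTextReader("ngrams_1.bkt.txt.sorted.zst")
for line in reader.read_tqdm():
    # each line is "<ngram> <document_id>", the format scanned by the decontamination code below
    ngram, doc_id = line.rsplit(" ", 1)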
import time
import random
import pickle
import json
import glob
import os
import collections
from .janitor import Janitor, word_ngrams
from .archiver import ZStdTextReader
# Was used for testing the evaluator decoupled from the full logic below
def get_train_overlap_stub(docs, ngrams_path, ngrams_n_size):
simulated_overlap = 0.1
contaminated = int(len(docs) * simulated_overlap)
return random.sample(range(len(docs)), contaminated)
# Returns a dictionary containing all overlapping documents in each
# task. In the standard use case, an overlap occurs when any of the 13-grams
# found in the task document exist in the training set documents.
#
# To generate 13-grams for the pile see scripts/clean_training_data. The final output of these
# scripts are an info.json file containing the n_gram_size (13) and a bunch of "ngrams_{x}.bkt.txt.sorted.zst"
# files. These should exist in the "ngrams_path" provided to this function.
# Algorithm:
# 1. Build lookups for each dataset {ngram: list(document_ids)}
# 2. Merge into an overall lookup {ngram: [(task_name, task_set, doc_ids),]}
# 3. Full scan the 13-grams from the training set against the merged lookup,
# saving matches in the "duplicates" dictionary {(task_name, task_set): set(doc_ids)}
# 4. Strip the task_set from the dictionary keys and return
#
# We cache the task+set lookups as well as the overlaps.
def get_train_overlap(docs_by_task_set, ngrams_path, limit):
# return get_train_overlap_stub(docs, ngrams_path, ngrams_n_size)
info_dict_path = os.path.join(ngrams_path, "info.json")
info_dict = json.load(open(info_dict_path, "r"))
ngrams_n_size = info_dict["ngram_size"]
janitor = Janitor()
# Build lookup for each dataset first in case we use different task combinations later
print("Building Lookups...")
start = time.perf_counter()
def get_overlaps_dump_path(task_name, task_set, ngrams_n_size, limit):
return f"data/{task_name}/{task_set}_{ngrams_n_size}grams_limit{limit}.overlaps"
lookups = {}
duplicates = {} # {(task_name, task_set): set(doc_ids)}
sets_to_decontaminate = len(docs_by_task_set.keys())
for (task_name, task_set), docs in docs_by_task_set.items():
if not os.path.exists(f"data/{task_name}"):
os.mkdir(f"data/{task_name}")
# Check if we've decontaminated this combination before
overlaps_dump_path = get_overlaps_dump_path(task_name, task_set, ngrams_n_size, limit)
if os.path.exists(overlaps_dump_path):
duplicates[(task_name, task_set)] = pickle.load(open(overlaps_dump_path, "rb"))
sets_to_decontaminate -= 1
continue
else:
duplicates[(task_name, task_set)] = set()
# Build/load the task lookup {ngram: set(documents)}.
task_set_lookup_path = f"data/{task_name}/{task_set}_{ngrams_n_size}grams_limit{limit}.lookup"
if os.path.exists(task_set_lookup_path):
print(f"{task_set_lookup_path} available, loading...")
lookups[(task_name, task_set)] = pickle.load(open(task_set_lookup_path, "rb"))
else:
print(f"{task_set_lookup_path} not available, building...")
lookup = collections.defaultdict(set)
for doc_id, document in enumerate(docs):
ngrams = word_ngrams(janitor.normalize_string(document), ngrams_n_size)
for ngram in ngrams:
lookup[ngram].add(doc_id)
pickle.dump(lookup, open(task_set_lookup_path,"wb"))
lookups[(task_name, task_set)] = lookup
elapsed = time.perf_counter() - start
print(f"Building lookups took {elapsed:0.5f} seconds.")
matched_ngrams = []
if sets_to_decontaminate > 0:
print("Merging lookups...")
start = time.perf_counter()
merged_lookup = collections.defaultdict(list)
for (task_name, task_set), lookup in lookups.items():
for ngram, doc_ids in lookup.items():
merged_lookup[ngram].append((task_name, task_set, doc_ids))
elapsed = time.perf_counter() - start
print(f"Merging lookups took {elapsed:0.5f} seconds.")
print(f"{ngrams_n_size} grams files found in {ngrams_path}:")
files = glob.glob(os.path.join(ngrams_path, f"*.sorted.zst"))
print(files)
for file in files:
start = time.perf_counter()
print(f"Scanning {file}")
reader = ZStdTextReader(file)
total_ngrams = 0
unique_ngrams = 0
matching_unique = 0
non_matching_unique = 0
current_ngram = ""
for line in reader.read_tqdm(): # Scan training set ngrams file
total_ngrams += 1
[ngram, document_id] = line.rsplit(" ", 1)
if ngram != current_ngram: # Only need to match the ngram once in training set
unique_ngrams += 1
current_ngram = ngram
if ngram in merged_lookup:
matched_ngrams.append(ngram) # For logging
matching_unique += 1
for task_name, task_set, doc_ids in merged_lookup[ngram]:
task_doc_set = duplicates[(task_name, task_set)]
for doc_id in doc_ids: # Record contamination across all relevant task/set combos
task_doc_set.add(doc_id)
del merged_lookup[ngram] # No point matching again
else:
non_matching_unique += 1
print(f"Total Ngrams: {total_ngrams}")
print(f"Unique Ngrams: {unique_ngrams}")
print(f"Unique Matching: {matching_unique}")
print(f"Unique Non Matching: {non_matching_unique}")
print("Matched ngrams:")
for ngram in matched_ngrams:
print(ngram)
elapsed = time.perf_counter() - start
print(f"Read took {elapsed:0.5f} seconds.")
print(f"Speed: {(os.path.getsize(file)/1000000.0)/elapsed}MB/second")
print(duplicates)
# Dump overlaps separately
for (task_name, task_set), doc_ids in duplicates.items():
overlaps_dump_path = get_overlaps_dump_path(task_name, task_set, ngrams_n_size, limit)
pickle.dump(doc_ids, open(overlaps_dump_path,"wb"))
# Strip task set and return
return {task_name: doc_ids for (task_name, task_set), doc_ids in duplicates.items()}
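To make the expected input concrete, here is a hedged sketch of calling get_train_overlap directly; the path and task name are placeholders, and in normal use the evaluate() changes below build docs_by_task_set automatically.

# "path/to/pile_ngrams" must contain info.json plus the *.sorted.zst files
# produced by scripts/clean_training_data; "lambada"/"test" are illustrative.
docs_by_task_set = {
    ("lambada", "test"): ["first task document", "second task document"],
}
overlaps = get_train_overlap(docs_by_task_set, "path/to/pile_ngrams", limit=None)
# -> {"lambada": {0, 3, ...}} mapping each task to its contaminated doc ids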
@@ -6,15 +6,18 @@ import lm_eval.metrics
import lm_eval.models
import lm_eval.tasks
import lm_eval.base
+import lm_eval.decontamination
import numpy as np
from lm_eval.utils import positional_deprecated, run_task_tests
+from lm_eval.decontamination.decontaminate import get_train_overlap
@positional_deprecated
def simple_evaluate(model, model_args=None, tasks=[],
num_fewshot=0, batch_size=None, device=None,
no_cache=False, limit=None, bootstrap_iters=100000,
-description_dict=None, check_integrity=False):
+description_dict=None, check_integrity=False,
+decontamination_ngrams_path=None):
"""Instantiate and evaluate a model on a list of tasks.
:param model: Union[str, LM]
@@ -72,7 +75,8 @@ def simple_evaluate(model, model_args=None, tasks=[],
task_dict=task_dict,
num_fewshot=num_fewshot,
limit=limit,
-description_dict=description_dict
+description_dict=description_dict,
+decontamination_ngrams_path=decontamination_ngrams_path,
)
# add info about the model and few shot config
@@ -90,9 +94,11 @@ def simple_evaluate(model, model_args=None, tasks=[],
return results
+decontaminate_suffix = "_decontaminate"
@positional_deprecated
-def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None, bootstrap_iters=100000, description_dict=None):
+def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None, bootstrap_iters=100000, description_dict=None,
+decontamination_ngrams_path=None):
"""Instantiate and evaluate a model on a list of tasks.
:param lm: obj
@@ -120,6 +126,8 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
# nudge people to not specify it at all
print("WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict")
+decontaminate = decontamination_ngrams_path is not None
task_dict_items = [
(name, task)
for name, task in task_dict.items()
@@ -132,6 +140,8 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
requests = collections.defaultdict(list)
requests_origin = collections.defaultdict(list)
+overlaps = collections.defaultdict(list) # {task_name: contaminated_docs}
# If we ever run into issues where the eval tasks don't fit in memory and we can't afford a machine with bigger
# memory, we can always modify this plumbing to support that, but I didn't want to include it just yet because
# over-engineering is bad (or we could make it write the requests to disk and then read them back out again
@@ -140,6 +150,8 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
# TODO: we need unit tests & sanity checks or something to ensure that the return of `validation_docs` is stable
docs = {}
+docs_for_decontamination = collections.defaultdict(list)
# get lists of each type of request
for task_name, task in task_dict_items:
versions[task_name] = task.VERSION
@@ -147,7 +159,9 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
# TODO: the test-fallback-to-val system isn't final, we should revisit it at some point
if task.has_test_docs():
task_doc_func = task.test_docs
+task_set = "test" # Required for caching in the decontamination
elif task.has_validation_docs():
+task_set = "val" # Required for caching in the decontamination
task_doc_func = task.validation_docs
else:
raise RuntimeError("Task has neither test_docs nor validation_docs")
@@ -161,6 +175,10 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
description = description_dict[task_name] if description_dict and task_name in description_dict else ""
for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):
+if decontaminate and task.should_decontaminate():
+docs_for_decontamination[(task_name, task_set)].append(task.doc_to_decontamination_query(doc))
docs[(task_name, doc_id)] = doc
ctx = task.fewshot_context(
doc=doc,
@@ -177,6 +195,11 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
# doc_id: unique id that we can get back to a doc using `docs`
requests_origin[req.request_type].append((i, task_name, doc, doc_id))
+# Compare all tasks/sets at once to ensure a single training set scan
+if decontaminate:
+print("Finding train/test overlap, please wait...")
+overlaps = get_train_overlap(docs_for_decontamination, decontamination_ngrams_path, limit)
# all responses for each (task, doc)
process_res_queue = collections.defaultdict(list)
@@ -207,18 +230,28 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
metrics = task.process_results(doc, requests)
for metric, value in metrics.items():
vals[(task_name, metric)].append(value)
+# Re-use the evaluation for the decontaminated set by just ignoring the overlaps
+if decontaminate and task_name in overlaps:
+if doc_id not in overlaps[task_name]:
+vals[(task_name, metric + decontaminate_suffix)].append(value)
# aggregate results
for (task_name, metric), items in vals.items():
task = task_dict[task_name]
-results[task_name][metric] = task.aggregation()[metric](items)
+real_metric = metric # key when looking up the metric with task.aggregation
+if metric.endswith(decontaminate_suffix):
+real_metric = metric.replace(decontaminate_suffix, "") # decontaminated still uses the same metric
+results[task_name][metric] = task.aggregation()[real_metric](items)
# hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
# so we run them less iterations. still looking for a cleaner way to do this
stderr = lm_eval.metrics.stderr_for_metric(
-metric=task.aggregation()[metric],
+metric=task.aggregation()[real_metric],
bootstrap_iters=min(bootstrap_iters, 1000) if metric in ["bleu", "chrf", "ter"] else bootstrap_iters,
)
if stderr is not None:
results[task_name][metric + "_stderr"] = stderr(items)
...
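Putting the new plumbing together, a hedged end-to-end sketch (model, task, and n-grams path are placeholders); when overlaps are found, each metric is reported a second time under "<metric>_decontaminate":

from lm_eval import evaluator

results = evaluator.simple_evaluate(
    model="gpt2",                                        # illustrative model
    tasks=["lambada"],
    num_fewshot=0,
    decontamination_ngrams_path="path/to/pile_ngrams",   # the new argument
)
print(results["results"]["lambada"])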
@@ -3,7 +3,7 @@ from collections.abc import Iterable
import numpy as np
import sacrebleu
-import sklearn
+import sklearn.metrics
import random
@@ -245,3 +245,10 @@ def stderr_for_metric(metric, bootstrap_iters):
}
return stderr.get(metric, None)
+def yesno(x):
+if x:
+return 'yes'
+else:
+return 'no'
@@ -12,9 +12,14 @@ class HFLM(BaseLM):
assert isinstance(pretrained, str)
assert isinstance(batch_size, int)
if device:
+if device not in ["cuda", "cpu"]:
+device = int(device)
self._device = torch.device(device)
+print(f"Using device '{device}'")
else:
+print("Device not specificed")
+print(f"Cuda Available? {torch.cuda.is_available()}")
self._device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# TODO: update this to be less of a hack once subfolder is fixed in HF
...
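The intent of the device change is to let a numeric string such as "0" select a CUDA ordinal; a small illustration of the conversion it now performs (assumed behavior, standard torch semantics):

import torch

# "0" is neither "cuda" nor "cpu", so it is first cast with int();
# torch.device(0) is equivalent to torch.device("cuda:0").
print(torch.device(int("0")))  # -> device(type='cuda', index=0)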
@@ -15,6 +15,7 @@ from . import wsc273
from . import winogrande
from . import quac
from . import hellaswag
+from . import swag
from . import openbookqa
from . import squad
from . import naturalqs
@@ -50,6 +51,7 @@ from . import truthfulqa
from . import blimp
from . import asdiv
from . import gsm8k
+from . import storycloze
########################################
# Translation tasks
@@ -135,8 +137,8 @@ TASK_REGISTRY = {
# "quac": quac.QuAC, # not implemented yet
"logiqa": logiqa.LogiQA,
"hellaswag": hellaswag.HellaSwag,
+"swag": swag.SWAG,
"openbookqa": openbookqa.OpenBookQA,
-# "sat": sat.SATAnalogies, # not implemented yet
"squad2": squad.SQuAD2,
"race": race.RACE,
# "naturalqs": naturalqs.NaturalQs, # not implemented yet
@@ -297,6 +299,11 @@ TASK_REGISTRY = {
"blimp_wh_vs_that_no_gap_long_distance": blimp.BlimpWhVsThatNoGapLongDistance,
"blimp_wh_vs_that_with_gap": blimp.BlimpWhVsThatWithGap,
"blimp_wh_vs_that_with_gap_long_distance": blimp.BlimpWhVsThatWithGapLongDistance,
+# Requires manual download of data.
+# "storycloze_2016": storycloze.StoryCloze2016,
+# "storycloze_2018": storycloze.StoryCloze2018,
+# "sat": sat.SATAnalogies,
}
...
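As a quick check that the newly registered task resolves, something like the following should work (a sketch using the existing get_task_dict helper; "swag" is now in TASK_REGISTRY, while the storycloze and SAT entries stay commented out because their data must be obtained manually):

from lm_eval import tasks

task_dict = tasks.get_task_dict(["swag"])
print(task_dict["swag"].VERSION)  # instantiating the task will download the HF "swag" dataset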
@@ -10,9 +10,8 @@ provided explanations.
Homepage: "https://github.com/facebookresearch/anli"
"""
import numpy as np
-from lm_eval.base import rf
+from lm_eval.base import rf, Task
-from ..metrics import mean
+from lm_eval.metrics import mean
-from . common import HFTask
_CITATION = """
@@ -31,7 +30,7 @@ _CITATION = """
"""
-class ANLIBase(HFTask):
+class ANLIBase(Task):
VERSION = 0
DATASET_PATH = "anli"
DATASET_NAME = None
@@ -49,16 +48,16 @@ class ANLIBase(HFTask):
def training_docs(self):
if self.has_training_docs():
if self._training_docs is None:
-self._training_docs = list(self.data["train_r" + str(self.SPLIT)])
+self._training_docs = list(self.dataset["train_r" + str(self.SPLIT)])
return self._training_docs
def validation_docs(self):
if self.has_validation_docs():
-return self.data["dev_r" + str(self.SPLIT)]
+return self.dataset["dev_r" + str(self.SPLIT)]
def test_docs(self):
if self.has_test_docs():
-return self.data["test_r" + str(self.SPLIT)]
+return self.dataset["test_r" + str(self.SPLIT)]
def doc_to_text(self, doc):
# OA does this a bit weirdly: they prepend "anli 1: anli 1: " to the beginning
@@ -67,6 +66,12 @@ class ANLIBase(HFTask):
# want to do it exactly as OA did?
return doc['premise'] + '\nQuestion: ' + doc['hypothesis'] + ' True, False, or Neither?\nAnswer:'
+def should_decontaminate(self):
+return True
+def doc_to_decontamination_query(self, doc):
+return doc["premise"]
def doc_to_target(self, doc):
# True = entailment
# False = contradiction
@@ -125,11 +130,14 @@ class ANLIBase(HFTask):
"acc": True
}
class ANLIRound1(ANLIBase):
SPLIT = 1
class ANLIRound2(ANLIBase):
SPLIT = 2
class ANLIRound3(ANLIBase):
SPLIT = 3
@@ -13,7 +13,6 @@ a co-occurrence method fail to answer correctly) and an Easy Set of 5,197 questi
Homepage: https://allenai.org/data/arc
"""
from lm_eval.base import MultipleChoiceTask
-from . common import HFTask
_CITATION = """
@@ -27,7 +26,7 @@ _CITATION = """
"""
-class ARCEasy(HFTask, MultipleChoiceTask):
+class ARCEasy(MultipleChoiceTask):
VERSION = 0
DATASET_PATH = "ai2_arc"
DATASET_NAME = "ARC-Easy"
@@ -41,7 +40,18 @@ class ARCEasy(HFTask, MultipleChoiceTask):
def has_test_docs(self):
return True
-def _convert_standard(self, doc):
+def training_docs(self):
+if self._training_docs is None:
+self._training_docs = list(map(self._process_doc, self.dataset["train"]))
+return self._training_docs
+def validation_docs(self):
+return map(self._process_doc, self.dataset["validation"])
+def test_docs(self):
+return map(self._process_doc, self.dataset["test"])
+def _process_doc(self, doc):
# NOTE: Some `doc["answerKey"]`s are in numeric string format being one
# of {'1', '2', '3', '4', '5'}. We map them back to letters.
num_to_letter = {"1": "A", "2": "B", "3": "C", "4": "D", "5": "E"}
@@ -57,6 +67,12 @@ class ARCEasy(HFTask, MultipleChoiceTask):
def doc_to_text(self, doc):
return doc["query"]
+def should_decontaminate(self):
+return True
+def doc_to_decontamination_query(self, doc):
+return doc["query"]
class ARCChallenge(ARCEasy):
DATASET_PATH = "ai2_arc"
...
@@ -7,13 +7,10 @@ problem in natural language.
Homepage: https://github.com/openai/gpt-3/tree/master/data
"""
-import abc
+import inspect
-import json
+import lm_eval.datasets.arithmetic.arithmetic
-import os
-from collections import namedtuple
from lm_eval.base import Task, rf
from lm_eval.metrics import mean
-from best_download import download_file
_CITATION = """
@@ -31,33 +28,9 @@ _CITATION = """
"""
-ArithmeticDoc = namedtuple('ArithmeticDoc', ['context', 'completion'])
class Arithmetic(Task):
VERSION = 0
-directory = 'data/arithmetic/'
+DATASET_PATH = inspect.getfile(lm_eval.datasets.arithmetic.arithmetic)
-def __init__(self):
-super().__init__()
-def download(self):
-file_name, checksum = self.get_file_download_info()
-url = 'https://raw.githubusercontent.com/openai/gpt-3/master/data/' + file_name
-if not os.path.exists(self.directory):
-os.makedirs(self.directory)
-download_file(url, local_file=self.directory+file_name, expected_checksum=checksum)
-self.set_docs()
-@abc.abstractmethod
-def get_file_download_info(self):
-"""returns a tuple of (file_name, checksum)"""
-pass
-def set_docs(self):
-file_name, _ = self.get_file_download_info()
-jsons = open(self.directory+file_name, 'r')
-self._docs = [self.load_doc(json.loads(line)) for line in jsons]
def has_training_docs(self):
return False
@@ -72,25 +45,25 @@ class Arithmetic(Task):
return NotImplemented
def validation_docs(self):
-return self._docs
+return self.dataset["validation"]
def test_docs(self):
return NotImplemented
def doc_to_text(self, doc):
-return doc.context
+return doc["context"]
+def should_decontaminate(self):
+return True
+def doc_to_decontamination_query(self, doc):
+return doc["context"]
def doc_to_target(self, doc):
-return doc.completion
+return doc["completion"]
-def load_doc(self, doc_json):
-return ArithmeticDoc(context=doc_json['context'].strip()
-.replace('\n\n', '\n')
-.replace('Q:', 'Question:')
-.replace('A:', 'Answer:'), completion=doc_json['completion'])
def construct_requests(self, doc, ctx):
-ll, is_prediction = rf.loglikelihood(ctx, doc.completion)
+ll, is_prediction = rf.loglikelihood(ctx, doc["completion"])
return is_prediction
def process_results(self, doc, results):
@@ -111,41 +84,40 @@ class Arithmetic(Task):
class Arithmetic2DPlus(Arithmetic):
-def get_file_download_info(self):
+DATASET_NAME = "arithmetic_2da"
-return 'two_digit_addition.jsonl', '75a54b7a3db3b23369df74fe440c23025f3d3c51f664300bd3d56632b2617b3d'
class Arithmetic2DMinus(Arithmetic):
-def get_file_download_info(self):
+DATASET_NAME = "arithmetic_2ds"
-return 'two_digit_subtraction.jsonl', 'da956066ff108c00b341d360567472784f5fd872d6465071b44a14291205bc03'
class Arithmetic3DPlus(Arithmetic):
-def get_file_download_info(self):
+DATASET_NAME = "arithmetic_3da"
-return 'three_digit_addition.jsonl', '124865e30efd2abfbc1855dd34c218fc02d32d780ace970ab9b4ea3fa74c798b'
class Arithmetic3DMinus(Arithmetic):
-def get_file_download_info(self):
+DATASET_NAME = "arithmetic_3ds"
-return 'three_digit_subtraction.jsonl', '7fc6aaedcb0e2bd17c398dd4147c5585b1e608278a8e98b914e69656707d6a29'
class Arithmetic4DPlus(Arithmetic):
-def get_file_download_info(self):
+DATASET_NAME = "arithmetic_4da"
-return 'four_digit_addition.jsonl', '459c6f75baa2e8d7cf50bdd07db6d0ca9133a6b137d95d09267db85b6e07f391'
class Arithmetic4DMinus(Arithmetic):
-def get_file_download_info(self):
+DATASET_NAME = "arithmetic_4ds"
-return 'four_digit_subtraction.jsonl', '0c47db40a10c052ef0cf732a9ef2edaa53d66377d43eb47a9c382d33a8af7102'
class Arithmetic5DPlus(Arithmetic):
-def get_file_download_info(self):
+DATASET_NAME = "arithmetic_5da"
-return 'five_digit_addition.jsonl', '30ada42efe315b958c6e9649274005d3b720e50298e92c3a2d321f8996e58f54'
class Arithmetic5DMinus(Arithmetic):
-def get_file_download_info(self):
+DATASET_NAME = "arithmetic_5ds"
-return 'five_digit_subtraction.jsonl', '8b98ccfc943cbf9193bcf1984954aa0b1a4527016072d972a2b055cc1482ca3c'
class Arithmetic2DMultiplication(Arithmetic):
-def get_file_download_info(self):
+DATASET_NAME = "arithmetic_2dm"
-return 'two_digit_multiplication.jsonl', '5613d1d1cc3b2c03edc1990252247d34c10ec82944b2cdeb19e71b00f237f431'
class Arithmetic1DComposite(Arithmetic):
-def get_file_download_info(self):
+DATASET_NAME = "arithmetic_1dc"
-return 'single_digit_three_ops.jsonl', '08b34e3272a8ff1d4932d63f251519d14c485c38d582366e1e323d0b859c3925'
@@ -14,15 +14,10 @@ NOTE: We currently ignore formulas for answer generation.
Homepage: https://github.com/chaochun/nlu-asdiv-dataset
"""
-from lm_eval.base import Task
+import inspect
-from pathlib import Path
+import lm_eval.datasets.asdiv.asdiv
-from best_download import download_file
+from lm_eval.base import rf, Task
-import xml.etree.ElementTree as ET
+from lm_eval.metrics import mean
-from lm_eval.base import rf
-from lm_eval.metrics import mean,perplexity
-import numpy as np
-from zipfile import ZipFile
-import os
_CITATION = """
@@ -39,39 +34,11 @@ _CITATION = """
class Asdiv(Task):
VERSION = 0
-DATASET_PATH = Path("data/asdiv")
+DATASET_PATH = inspect.getfile(lm_eval.datasets.asdiv.asdiv)
-def download(self):
-if self.DATASET_PATH.exists():
-return
-Path.mkdir(self.DATASET_PATH, parents=True)
-url = "https://github.com/chaochun/nlu-asdiv-dataset/archive/55790e5270bb91ccfa5053194b25732534696b50.zip"
-checksum = "8f1fe4f6d5f170ec1e24ab78c244153c14c568b1bb2b1dad0324e71f37939a2d"
-zip_path = self.DATASET_PATH / "55790e5270bb91ccfa5053194b25732534696b50.zip"
-download_file(url, local_file=str(zip_path), expected_checksum=checksum)
-with ZipFile(zip_path, "r") as zip:
-zip.extractall(self.DATASET_PATH)
-os.remove(zip_path)
-def _convert_standard(self, problem):
-#TODO: include solution-type and formula
-out_doc = {
-"question" : problem.find('Question').text,
-"body" : problem.find('Body').text,
-"answer": problem.find('Answer').text
-}
-return out_doc
-def load_docs(self, textfilename, tfds=False):
-tree = ET.parse(textfilename)
-root = tree.getroot()
-for pid, problem in enumerate(root.iter('Problem')):
-out_doc = self._convert_standard(problem)
-yield out_doc
def has_training_docs(self):
return False
def has_validation_docs(self):
return True
@@ -81,13 +48,12 @@ class Asdiv(Task):
def training_docs(self):
raise NotImplementedError("This dataset has no training docs")
+def validation_docs(self):
+return self.dataset["validation"]
def test_docs(self):
raise NotImplementedError("This dataset has no test docs")
-def validation_docs(self):
-data_xml_path = self.DATASET_PATH / "nlu-asdiv-dataset-55790e5270bb91ccfa5053194b25732534696b50/dataset/ASDiv.xml"
-return self.load_docs(data_xml_path)
def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, description=None):
assert num_fewshot == 0, "ASDiv is intended only for the zero-shot setting."
return super().fewshot_context(
@@ -101,6 +67,12 @@ class Asdiv(Task):
# TODO: add solution-type
return doc['body'] + '\n' + 'Question:' + doc['question'] + '\n' + 'Answer:'
+def should_decontaminate(self):
+return True
+def doc_to_decontamination_query(self, doc):
+return doc['body'] + " " + doc['question']
def doc_to_target(self, doc):
# TODO: add formula
...
@@ -10,9 +10,8 @@ grammars.
Homepage: https://github.com/alexwarstadt/blimp
"""
-from lm_eval.base import rf
+from lm_eval.base import rf, Task
from lm_eval.metrics import mean
-from .common import HFTask
_CITATION = """
@@ -32,19 +31,24 @@ _CITATION = """
"""
-class BlimpTask(HFTask):
+class BlimpTask(Task):
VERSION = 0
DATASET_PATH = "blimp"
-def download(self):
+def has_training_docs(self):
-super().download()
+return False
+def has_validation_docs(self):
+return True
+def has_test_docs(self):
+return False
+def validation_docs(self):
# The HF dataset only contains a "train" dataset, but the harness expects a "validation"
# dataset. Let's use the training dataset, on the assumption that the model wasn't actually
# trained on this data.
+return self.dataset["train"]
-self.data["validation"] = self.data["train"]
-del self.data["train"]
def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, description=None):
assert num_fewshot == 0
@@ -64,6 +68,12 @@ class BlimpTask(HFTask):
# this method is invoked by tests only
return ""
+def should_decontaminate(self):
+return True
+def doc_to_decontamination_query(self, doc):
+return doc["sentence_good"] + " " + doc["sentence_bad"]
def doc_to_target(self, doc):
# this method is invoked by tests only
return ""
...
@@ -13,9 +13,8 @@ used by the Recurrent Language Models described in the paper. See section 4.4.
Homepage: https://github.com/facebookresearch/ParlAI/tree/main/parlai/tasks/cbt
"""
import numpy as np
-from lm_eval.base import rf
+from lm_eval.base import rf, Task
from lm_eval.metrics import mean
-from .common import HFTask
_CITATION = """
@@ -30,11 +29,30 @@ _CITATION = """
"""
-class CBTBase(HFTask):
+class CBTBase(Task):
VERSION = 0
DATASET_PATH = "cbt"
DATASET_NAME = None
+def has_training_docs(self):
+return True
+def has_validation_docs(self):
+return True
+def has_test_docs(self):
+return True
+def training_docs(self):
+if self._training_docs is None:
+self._training_docs = list(self.dataset["train"])
+return self._training_docs
+def validation_docs(self):
+return self.dataset["validation"]
+def test_docs(self):
+return self.dataset["test"]
def detokenize(self, text):
text = text.replace(" '", "'")
@@ -57,6 +75,13 @@ class CBTBase(HFTask):
text = "Passage: " + passage + "\nQuestion: " + doc["question"]
return self.detokenize(text)
+def should_decontaminate(self):
+return True
+def doc_to_decontamination_query(self, doc):
+passage = " ".join(doc["sentences"])
+return passage
def doc_to_target(self, doc):
return ""
...
import datasets
from ..base import Task
class HFTask(Task):
DATASET_PATH = None
DATASET_NAME = None
def __init__(self):
self.data = None
super().__init__()
def download(self):
self.data = datasets.load_dataset(path=self.DATASET_PATH, name=self.DATASET_NAME)
def has_training_docs(self):
"""Whether the task has a training set"""
return True if "train" in self.data.keys() else False
def has_validation_docs(self):
"""Whether the task has a validation set"""
return True if "validation" in self.data.keys() else False
def has_test_docs(self):
"""Whether the task has a test set"""
return True if "test" in self.data.keys() else False
def _convert_standard(self, doc):
return doc
def training_docs(self):
# Cache training for faster few-shot.
# If data is too large to fit in memory, override this method.
if self.has_training_docs():
if self._training_docs is None:
self._training_docs = list(map(self._convert_standard, self.data["train"]))
return self._training_docs
def validation_docs(self):
if self.has_validation_docs():
return map(self._convert_standard, self.data["validation"])
def test_docs(self):
if self.has_test_docs():
return map(self._convert_standard, self.data["test"])
def yesno(x):
if x:
return 'yes'
else:
return 'no'
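With common.py removed, every task now inherits from lm_eval.base.Task directly and reads its splits from self.dataset, which the base class populates via datasets.load_dataset from DATASET_PATH/DATASET_NAME. A minimal sketch of the new pattern mirrored by the per-task diffs in this commit (names are illustrative; the prompt and metric methods that Task also requires are omitted):

from lm_eval.base import Task

class ExampleTask(Task):
    VERSION = 0
    DATASET_PATH = "glue"   # any HF dataset path, or inspect.getfile(...) for a local script
    DATASET_NAME = "cola"

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return False

    def training_docs(self):
        if self._training_docs is None:
            # self.dataset replaces the old HFTask self.data
            self._training_docs = list(self.dataset["train"])
        return self._training_docs

    def validation_docs(self):
        return self.dataset["validation"]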
@@ -9,13 +9,11 @@ appear in a conversation.
Homepage: https://stanfordnlp.github.io/coqa/
"""
-import os
+import inspect
-import json
import transformers.data.metrics.squad_metrics as squad_metrics
+import lm_eval.datasets.coqa.coqa
from lm_eval.base import Task, rf, mean
-from ..utils import sh
from itertools import zip_longest
-from best_download import download_file
_CITATION = """
@@ -32,15 +30,8 @@ _CITATION = """
class CoQA(Task):
VERSION = 1
+DATASET_PATH = inspect.getfile(lm_eval.datasets.coqa.coqa)
-def download(self):
+DATASET_NAME = None
-coqa_train_filepath = 'data/coqa/coqa-train-v1.0.json'
-coqa_dev_filepath = 'data/coqa/coqa-dev-v1.0.json'
-sh ("""mkdir -p data/coqa""")
-download_file("http://downloads.cs.stanford.edu/nlp/data/coqa/coqa-train-v1.0.json", local_file=coqa_train_filepath, expected_checksum="b0fdb2bc1bd38dd3ca2ce5fa2ac3e02c6288ac914f241ac409a655ffb6619fa6")
-download_file("http://downloads.cs.stanford.edu/nlp/data/coqa/coqa-dev-v1.0.json", local_file=coqa_dev_filepath, expected_checksum="dfa367a9733ce53222918d0231d9b3bedc2b8ee831a2845f62dfc70701f2540a")
def has_training_docs(self):
return True
@@ -52,10 +43,10 @@ class CoQA(Task):
return False
def training_docs(self):
-return json.load(open('data/coqa/coqa-train-v1.0.json'))['data']
+return self.dataset["train"]
def validation_docs(self):
-return json.load(open('data/coqa/coqa-dev-v1.0.json'))['data']
+return self.dataset["validation"]
def test_docs(self):
pass
@@ -64,23 +55,29 @@ class CoQA(Task):
# Given a passage p, the conversation history {q1, a1, . . . qi−1, ai−1}
# and a question qi, the task is to predict the answer ai
doc_text = doc["story"] + '\n\n'
-for (q, a) in zip_longest(doc["questions"], doc["answers"][:-1]): # omit target answer ai
+for (q, a) in zip_longest(doc["questions"]["input_text"], doc["answers"]["input_text"][:-1]): # omit target answer ai
-question = f"Q: {q['input_text']}" + '\n\n'
+question = f"Q: {q}\n\n"
-answer = f"A: {a['input_text']}" + '\n\n' if a is not None else "A:"
+answer = f"A: {a}\n\n" if a is not None else "A:"
doc_text += question + answer
return doc_text
+def should_decontaminate(self):
+return True
+def doc_to_decontamination_query(self, doc):
+return doc["story"] + " " + "\n".join(doc["questions"]["input_text"])
@classmethod
def get_answers(cls, doc, turn_id):
# Returns unique answers and valid alternatives (Some questions in CoQA have multiple valid answers).
answers = []
-answer_forturn = doc["answers"][turn_id - 1]["input_text"]
+answer_forturn = doc["answers"]["input_text"][turn_id - 1]
answers.append(answer_forturn)
additional_answers = doc.get("additional_answers")
if additional_answers:
for key in additional_answers:
-additional_answer_for_turn = additional_answers[key][turn_id - 1]["input_text"]
+additional_answer_for_turn = additional_answers[key]["input_text"][turn_id - 1]
if additional_answer_for_turn.lower() not in map(str.lower, answers):
answers.append(additional_answer_for_turn)
return answers
@@ -120,8 +117,8 @@ class CoQA(Task):
def doc_to_target(self, doc, turnid=None):
# Default to prediction of last turn.
if turnid is None:
-turnid = len(doc["questions"])
+turnid = len(doc["questions"]["input_text"])
-raw_text = doc['answers'][turnid - 1]["input_text"]
+raw_text = doc['answers']["input_text"][turnid - 1]
return " " + raw_text
def construct_requests(self, doc, ctx):
@@ -148,7 +145,7 @@ class CoQA(Task):
:param results:
The results of the requests created in construct_requests.
"""
-turn_id = len(doc["questions"])
+turn_id = len(doc["questions"]["input_text"])
gold_list = self.get_answers(doc, turn_id)
pred = results[0].strip().split('\n')[0]
...
...@@ -12,16 +12,14 @@ Homepage: https://allenai.org/data/drop ...@@ -12,16 +12,14 @@ Homepage: https://allenai.org/data/drop
Acknowledgement: This implementation is based on the official evaluation for `DROP`: Acknowledgement: This implementation is based on the official evaluation for `DROP`:
https://github.com/allenai/allennlp-reading-comprehension/blob/master/allennlp_rc/eval/drop_eval.py https://github.com/allenai/allennlp-reading-comprehension/blob/master/allennlp_rc/eval/drop_eval.py
""" """
import json import inspect
import numpy as np import numpy as np
import re import re
import string import string
from best_download import download_file import lm_eval.datasets.drop.drop
from scipy.optimize import linear_sum_assignment from scipy.optimize import linear_sum_assignment
from lm_eval.base import Task, rf from lm_eval.base import Task, rf
from lm_eval.metrics import mean from lm_eval.metrics import mean
from pathlib import Path
from zipfile import ZipFile
_CITATION = """ _CITATION = """
...@@ -41,18 +39,8 @@ _ARTICLES = re.compile(r"\b(a|an|the)\b", re.UNICODE) ...@@ -41,18 +39,8 @@ _ARTICLES = re.compile(r"\b(a|an|the)\b", re.UNICODE)
class DROP(Task): class DROP(Task):
VERSION = 1 VERSION = 1
DATASET_PATH = Path("data/drop") DATASET_PATH = inspect.getfile(lm_eval.datasets.drop.drop)
DATASET_NAME = None
def download(self):
if self.DATASET_PATH.exists():
return
Path.mkdir(self.DATASET_PATH, parents=True)
url = "https://s3-us-west-2.amazonaws.com/allennlp/datasets/drop/drop_dataset.zip"
checksum = "39d2278a29fd729de301b111a45f434c24834f40df8f4ff116d864589e3249d6"
zip_path = self.DATASET_PATH / "drop_dataset.zip"
download_file(url, local_file=str(zip_path), expected_checksum=checksum)
with ZipFile(zip_path, "r") as zip:
zip.extractall(self.DATASET_PATH)
def has_training_docs(self): def has_training_docs(self):
return True return True
@@ -63,29 +51,46 @@ class DROP(Task):
     def has_test_docs(self):
         return False

-    def _load_docs(self, docs):
-        for doc in docs:
-            for qa in doc["qa_pairs"]:
-                yield {
-                    "id": qa["query_id"],
-                    "passage": doc["passage"],
-                    "question": qa["question"],
-                    "answers": self.get_answers(qa),
-                }
+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(map(self._process_doc, self.dataset["train"]))
+        return self._training_docs
+
+    def validation_docs(self):
+        return map(self._process_doc, self.dataset["validation"])
+
+    def _process_doc(self, doc):
+        return {
+            "id": doc["query_id"],
+            "passage": doc["passage"],
+            "question": doc["question"],
+            "answers": self.get_answers(doc),
+        }

     @classmethod
     def get_answers(cls, qa):
+        def _flatten_validated_answers(validated_answers):
+            """ Flattens a dict of lists of validated answers.
+            {"number": ['1', '8'], ...}
+            -> [{"number": ['1'], ...}, {"number": ['8'], ...}]
+            """
+            vas = []
+            for i in range(len(validated_answers["number"])):
+                vas.append({
+                    "number": validated_answers["number"][i],
+                    "date": validated_answers["date"][i],
+                    "spans": validated_answers["spans"][i],
+                })
+            return vas
+
         answers = []
         answers_set = set()
-        candidates = [qa["answer"]] + qa.get("validated_answers", [])
+        candidates = [qa["answer"]] + _flatten_validated_answers(qa["validated_answers"])
         for candidate in candidates:
             answer = cls.parse_answer(candidate)
             if answer in answers_set:
                 continue
             answers_set.add(answer)
             answers.append(answer)
         return answers

     @classmethod
@@ -99,17 +104,15 @@ class DROP(Task):
                 answer["date"]["month"],
                 answer["date"]["year"]]).strip(),)

-    def training_docs(self):
-        docs = json.load(open(self.DATASET_PATH / "drop_dataset" / "drop_dataset_train.json"))
-        return self._load_docs([docs[k] for k in docs.keys()])
-
-    def validation_docs(self):
-        docs = json.load(open(self.DATASET_PATH / "drop_dataset" / "drop_dataset_dev.json"))
-        return self._load_docs([docs[k] for k in docs.keys()])
-
     def doc_to_text(self, doc):
         return f"Passage: {doc['passage']}\nQuestion: {doc['question']}\nAnswer:"

+    def should_decontaminate(self):
+        return True
+
+    def doc_to_decontamination_query(self, doc):
+        return doc['passage'] + " " + doc['question']
+
     def doc_to_target(self, doc):
         return " " + ", ".join(doc["answers"][0])
...
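Note on the reworked `get_answers` above: with the HF `datasets` backend, each question's `validated_answers` arrives as a dict of parallel lists rather than a list of answer dicts, which is why the nested `_flatten_validated_answers` helper was added. A minimal, standalone sketch of that reshaping (the record below is made up, but shaped like the fields named in the diff):

# Illustrative only: reshape a column-oriented validated_answers record into
# one candidate dict per validated answer, mirroring _flatten_validated_answers.
validated_answers = {
    "number": ["1", "8"],
    "date": [{"day": "", "month": "", "year": ""}, {"day": "", "month": "", "year": ""}],
    "spans": [[], []],
}
flattened = [
    {
        "number": validated_answers["number"][i],
        "date": validated_answers["date"][i],
        "spans": validated_answers["spans"][i],
    }
    for i in range(len(validated_answers["number"]))
]
# flattened[0] == {"number": "1", "date": {"day": "", "month": "", "year": ""}, "spans": []}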
@@ -14,10 +14,9 @@ respect to a wide range of linguistic phenomena found in natural language.
 Homepage: https://gluebenchmark.com/
 """
 import numpy as np
-from lm_eval.base import rf
-from ..metrics import mean, matthews_corrcoef, f1_score
-from . common import HFTask, yesno
-from ..utils import general_detokenize
+from lm_eval.base import rf, Task
+from lm_eval.metrics import mean, matthews_corrcoef, f1_score, yesno
+from lm_eval.utils import general_detokenize

 # TODO(jon-tow): Add citations for the individual datasets/tasks that make up GLUE.
@@ -46,7 +45,7 @@ _CITATION = """
 # Single-Sentence Tasks

-class CoLA(HFTask):
+class CoLA(Task):
     VERSION = 0
     DATASET_PATH = "glue"
     DATASET_NAME = "cola"

@@ -60,9 +59,23 @@ class CoLA(HFTask):
     def has_test_docs(self):
         return False

+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(self.dataset["train"])
+        return self._training_docs
+
+    def validation_docs(self):
+        return self.dataset["validation"]
+
     def doc_to_text(self, doc):
         return "{}\nQuestion: Does this sentence make sense?\nAnswer:".format(doc["sentence"])

+    def should_decontaminate(self):
+        return True
+
+    def doc_to_decontamination_query(self, doc):
+        return doc["sentence"]
+
     def doc_to_target(self, doc):
         return " {}".format({1: "yes", 0: "no"}[doc["label"]])
@@ -90,7 +103,7 @@ class CoLA(HFTask):
     }

-class SST(HFTask):
+class SST(Task):
     VERSION = 0
     DATASET_PATH = "glue"
     DATASET_NAME = "sst2"

@@ -104,6 +117,14 @@ class SST(HFTask):
     def has_test_docs(self):
         return False

+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(self.dataset["train"])
+        return self._training_docs
+
+    def validation_docs(self):
+        return self.dataset["validation"]
+
     def doc_to_text(self, doc):
         return "{}\nQuestion: Is this sentence positive or negative?\nAnswer:".format(
             general_detokenize(doc["sentence"]),
@@ -139,7 +160,7 @@ class SST(HFTask):
 # Inference Tasks

-class MNLI(HFTask):
+class MNLI(Task):
     VERSION = 0
     DATASET_PATH = "glue"
     DATASET_NAME = "mnli"

@@ -153,13 +174,18 @@ class MNLI(HFTask):
     def has_test_docs(self):
         return False

+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(self.dataset["train"])
+        return self._training_docs
+
     def validation_docs(self):
         if self.has_validation_docs():
-            return self.data["validation_matched"]
+            return self.dataset["validation_matched"]

     def test_docs(self):
         if self.has_test_docs():
-            return self.data["test_matched"]
+            return self.dataset["test_matched"]

     def doc_to_text(self, doc):
         return "{}\nQuestion: {} True, False or Neither?\nAnswer:".format(
@@ -202,14 +228,14 @@ class MNLIMismatched(MNLI):
     def validation_docs(self):
         if self.has_validation_docs():
-            return self.data["validation_mismatched"]
+            return self.dataset["validation_mismatched"]

     def test_docs(self):
         if self.has_test_docs():
-            return self.data["test_mismatched"]
+            return self.dataset["test_mismatched"]

-class QNLI(HFTask):
+class QNLI(Task):
     VERSION = 0
     DATASET_PATH = "glue"
     DATASET_NAME = "qnli"

@@ -223,6 +249,14 @@ class QNLI(HFTask):
     def has_test_docs(self):
         return False

+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(self.dataset["train"])
+        return self._training_docs
+
+    def validation_docs(self):
+        return self.dataset["validation"]
+
     def doc_to_text(self, doc):
         return "{}\n{}\nQuestion: Does this response answer the question?\nAnswer:".format(
             doc["question"],
@@ -258,7 +292,7 @@ class QNLI(HFTask):
     }

-class WNLI(HFTask):
+class WNLI(Task):
     VERSION = 1
     DATASET_PATH = "glue"
     DATASET_NAME = "wnli"

@@ -272,6 +306,14 @@ class WNLI(HFTask):
     def has_test_docs(self):
         return False

+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(self.dataset["train"])
+        return self._training_docs
+
+    def validation_docs(self):
+        return self.dataset["validation"]
+
     def doc_to_text(self, doc):
         return "{}\nQuestion: {} True or False?\nAnswer:".format(
             doc["sentence1"],
@@ -307,7 +349,7 @@ class WNLI(HFTask):
     }

-class RTE(HFTask):
+class RTE(Task):
     VERSION = 0
     DATASET_PATH = "glue"
     DATASET_NAME = "rte"

@@ -321,6 +363,14 @@ class RTE(HFTask):
     def has_test_docs(self):
         return False

+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(self.dataset["train"])
+        return self._training_docs
+
+    def validation_docs(self):
+        return self.dataset["validation"]
+
     def doc_to_text(self, doc):
         return "{}\nQuestion: {} True or False?\nAnswer:".format(
             doc["sentence1"],
@@ -359,7 +409,7 @@ class RTE(HFTask):
 # Similarity and Paraphrase Tasks

-class MRPC(HFTask):
+class MRPC(Task):
     VERSION = 0
     DATASET_PATH = "glue"
     DATASET_NAME = "mrpc"

@@ -373,6 +423,14 @@ class MRPC(HFTask):
     def has_test_docs(self):
         return False

+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(self.dataset["train"])
+        return self._training_docs
+
+    def validation_docs(self):
+        return self.dataset["validation"]
+
     def doc_to_text(self, doc):
         return "Sentence 1: {}\nSentence 2: {}\nQuestion: Do both sentences mean the same thing?\nAnswer:".format(
             general_detokenize(doc["sentence1"]),
@@ -409,7 +467,7 @@ class MRPC(HFTask):
     }

-class QQP(HFTask):
+class QQP(Task):
     VERSION = 0
     DATASET_PATH = "glue"
     DATASET_NAME = "qqp"

@@ -423,6 +481,14 @@ class QQP(HFTask):
     def has_test_docs(self):
         return False

+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(self.dataset["train"])
+        return self._training_docs
+
+    def validation_docs(self):
+        return self.dataset["validation"]
+
     def doc_to_text(self, doc):
         return "Question 1: {}\nQuestion 2: {}\nQuestion: Do both questions ask the same thing?\nAnswer:".format(
             doc["question1"],
@@ -459,7 +525,7 @@ class QQP(HFTask):
     }

-class STSB(HFTask):
+class STSB(Task):
     VERSION = 0
     DATASET_PATH = "glue"
     DATASET_NAME = "stsb"

@@ -473,6 +539,17 @@ class STSB(HFTask):
     def has_test_docs(self):
         return True

+    def training_docs(self):
+        if self._training_docs is None:
+            self._training_docs = list(self.dataset["train"])
+        return self._training_docs
+
+    def validation_docs(self):
+        return self.dataset["validation"]
+
+    def test_docs(self):
+        return self.dataset["test"]
+
     def doc_to_text(self, doc):
         return "sentence 1: {}\nsentence 2: {}\nAnswer:".format(
             doc["sentence1"],
...
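The GLUE tasks above all gain the same two things: a memoized `training_docs`/`validation_docs` pair over the HF splits, and (for CoLA, as for DROP) the decontamination hooks `should_decontaminate` / `doc_to_decontamination_query`. As a rough, illustrative sketch of the contract those hooks expose — this is not the actual lm_eval decontamination code, and `collect_decontamination_queries` is a made-up helper:

# Hypothetical consumer of the new hooks: a task that opts in via
# should_decontaminate() provides, per document, the text a contamination
# filter would check against the pretraining corpus.
def collect_decontamination_queries(task, docs):
    if not task.should_decontaminate():
        return []
    return [task.doc_to_decontamination_query(doc) for doc in docs]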
@@ -16,10 +16,9 @@ model's sample/generation function.
 Homepage: https://github.com/openai/grade-school-math
 """
-import json
+import inspect
 import re
-from best_download import download_file
+import lm_eval.datasets.gsm8k.gsm8k
 from pathlib import Path
 from lm_eval.base import Task, rf
 from lm_eval.metrics import mean
@@ -43,21 +42,8 @@ INVALID_ANS = "[invalid]"
 class GradeSchoolMath8K(Task):
     VERSION = 0
-    DATASET_PATH = Path('data/gsm8k')
+    DATASET_PATH = inspect.getfile(lm_eval.datasets.gsm8k.gsm8k)
+    DATASET_NAME = None

-    def download(self):
-        if self.DATASET_PATH.exists():
-            return
-        Path.mkdir(self.DATASET_PATH, parents=True)
-        base_url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data"
-        splits = [
-            {"name": "train", "checksum": "17f347dc51477c50d4efb83959dbb7c56297aba886e5544ee2aaed3024813465"},
-            {"name": "test", "checksum": "3730d312f6e3440559ace48831e51066acaca737f6eabec99bccb9e4b3c39d14"},
-        ]
-        for split in splits:
-            file = self.DATASET_PATH / f"{split['name']}.jsonl"
-            url = f"{base_url}/{split['name']}.jsonl"
-            download_file(url, local_file=str(file), expected_checksum=split["checksum"])
-
     def has_training_docs(self):
         return True
@@ -68,17 +54,14 @@ class GradeSchoolMath8K(Task):
     def has_test_docs(self):
         return True

-    def _load_docs(self, file):
-        return (json.loads(line) for line in open(file).read().splitlines())
-
     def training_docs(self):
-        return self._load_docs(self.DATASET_PATH / "train.jsonl")
+        return self.dataset["train"]

     def validation_docs(self):
         raise NotImplementedError

     def test_docs(self):
-        return self._load_docs(self.DATASET_PATH / "test.jsonl")
+        return self.dataset["test"]

     def doc_to_text(self, doc):
         return "Question: " + doc['question'] + '\nAnswer:'
...
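Both DROP and GSM8K now follow the same convention: `DATASET_PATH` points at the packaged HF dataset script and the splits are read from `self.dataset`, replacing the hand-rolled `download()` methods. A minimal sketch of what that path resolution amounts to — assuming the base `Task` ends up doing the equivalent of `datasets.load_dataset` on that path; the exact wiring lives in `lm_eval.base` and is not part of this diff:

import inspect
import datasets
import lm_eval.datasets.gsm8k.gsm8k

# inspect.getfile resolves the installed location of the bundled dataset script,
# e.g. ".../lm_eval/datasets/gsm8k/gsm8k.py"
script_path = inspect.getfile(lm_eval.datasets.gsm8k.gsm8k)

# Loading that script yields a DatasetDict with the "train" and "test" splits
# that the task methods above index into.
gsm8k = datasets.load_dataset(script_path)
print(gsm8k["test"][0]["question"])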