Commit 1f8a8c1d authored by jon-tow

Merge branch 'master' of https://github.com/EleutherAI/lm-evaluation-harness into remove-dataset

parents b4c0275d b0acb337
{"wikitext-103-v1": {"description": " The WikiText language modeling dataset is a collection of over 100 million tokens extracted from the set of verified\n Good and Featured articles on Wikipedia. The dataset is available under the Creative Commons Attribution-ShareAlike\n License.\n", "citation": "@misc{merity2016pointer,\n title={Pointer Sentinel Mixture Models},\n author={Stephen Merity and Caiming Xiong and James Bradbury and Richard Socher},\n year={2016},\n eprint={1609.07843},\n archivePrefix={arXiv},\n primaryClass={cs.CL}\n}\n", "homepage": "https://blog.einstein.ai/the-wikitext-long-term-dependency-language-modeling-dataset/", "license": "Creative Commons Attribution-ShareAlike 4.0 International (CC BY-SA 4.0)", "features": {"page": {"dtype": "string", "id": null, "_type": "Value"}}, "post_processed": null, "supervised_keys": null, "task_templates": null, "builder_name": "wikitext", "config_name": "wikitext-103-v1", "version": {"version_str": "1.0.0", "description": null, "major": 1, "minor": 0, "patch": 0}, "splits": {"test": {"name": "test", "num_bytes": 1281262, "num_examples": 62, "dataset_name": "wikitext"}, "train": {"name": "train", "num_bytes": 539297488, "num_examples": 29444, "dataset_name": "wikitext"}, "validation": {"name": "validation", "num_bytes": 1142488, "num_examples": 60, "dataset_name": "wikitext"}}, "download_checksums": {"https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip": {"num_bytes": 190229076, "checksum": "242ba0f20b329cfdf1ccc61e9e9e5b59becf189db7f7a81cd2a0e2fc31539590"}}, "download_size": 190229076, "post_processing_size": null, "dataset_size": 541721238, "size_in_bytes": 731950314}, "wikitext-2-v1": {"description": " The WikiText language modeling dataset is a collection of over 100 million tokens extracted from the set of verified\n Good and Featured articles on Wikipedia. The dataset is available under the Creative Commons Attribution-ShareAlike\n License.\n", "citation": "@misc{merity2016pointer,\n title={Pointer Sentinel Mixture Models},\n author={Stephen Merity and Caiming Xiong and James Bradbury and Richard Socher},\n year={2016},\n eprint={1609.07843},\n archivePrefix={arXiv},\n primaryClass={cs.CL}\n}\n", "homepage": "https://blog.einstein.ai/the-wikitext-long-term-dependency-language-modeling-dataset/", "license": "Creative Commons Attribution-ShareAlike 4.0 International (CC BY-SA 4.0)", "features": {"page": {"dtype": "string", "id": null, "_type": "Value"}}, "post_processed": null, "supervised_keys": null, "task_templates": null, "builder_name": "wikitext", "config_name": "wikitext-2-v1", "version": {"version_str": "1.0.0", "description": null, "major": 1, "minor": 0, "patch": 0}, "splits": {"test": {"name": "test", "num_bytes": 1256634, "num_examples": 62, "dataset_name": "wikitext"}, "train": {"name": "train", "num_bytes": 10799034, "num_examples": 629, "dataset_name": "wikitext"}, "validation": {"name": "validation", "num_bytes": 1121860, "num_examples": 60, "dataset_name": "wikitext"}}, "download_checksums": {"https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip": {"num_bytes": 4475746, "checksum": "92675f1d63015c1c8b51f1656a52d5bdbc33aafa60cc47a218a66e7ee817488c"}}, "download_size": 4475746, "post_processing_size": null, "dataset_size": 13177528, "size_in_bytes": 17653274}, "wikitext-103-raw-v1": {"description": " The WikiText language modeling dataset is a collection of over 100 million tokens extracted from the set of verified\n Good and Featured articles on Wikipedia. 
The dataset is available under the Creative Commons Attribution-ShareAlike\n License.\n", "citation": "@misc{merity2016pointer,\n title={Pointer Sentinel Mixture Models},\n author={Stephen Merity and Caiming Xiong and James Bradbury and Richard Socher},\n year={2016},\n eprint={1609.07843},\n archivePrefix={arXiv},\n primaryClass={cs.CL}\n}\n", "homepage": "https://blog.einstein.ai/the-wikitext-long-term-dependency-language-modeling-dataset/", "license": "Creative Commons Attribution-ShareAlike 4.0 International (CC BY-SA 4.0)", "features": {"page": {"dtype": "string", "id": null, "_type": "Value"}}, "post_processed": null, "supervised_keys": null, "task_templates": null, "builder_name": "wikitext", "config_name": "wikitext-103-raw-v1", "version": {"version_str": "1.0.0", "description": null, "major": 1, "minor": 0, "patch": 0}, "splits": {"test": {"name": "test", "num_bytes": 1290775, "num_examples": 62, "dataset_name": "wikitext"}, "train": {"name": "train", "num_bytes": 540656522, "num_examples": 29444, "dataset_name": "wikitext"}, "validation": {"name": "validation", "num_bytes": 1147025, "num_examples": 60, "dataset_name": "wikitext"}}, "download_checksums": {"https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-raw-v1.zip": {"num_bytes": 191984949, "checksum": "91c00ae287f0d699e18605c84afc9e45c192bc6b7797ff8837e5474655a33794"}}, "download_size": 191984949, "post_processing_size": null, "dataset_size": 543094322, "size_in_bytes": 735079271}, "wikitext-2-raw-v1": {"description": " The WikiText language modeling dataset is a collection of over 100 million tokens extracted from the set of verified\n Good and Featured articles on Wikipedia. The dataset is available under the Creative Commons Attribution-ShareAlike\n License.\n", "citation": "@misc{merity2016pointer,\n title={Pointer Sentinel Mixture Models},\n author={Stephen Merity and Caiming Xiong and James Bradbury and Richard Socher},\n year={2016},\n eprint={1609.07843},\n archivePrefix={arXiv},\n primaryClass={cs.CL}\n}\n", "homepage": "https://blog.einstein.ai/the-wikitext-long-term-dependency-language-modeling-dataset/", "license": "Creative Commons Attribution-ShareAlike 4.0 International (CC BY-SA 4.0)", "features": {"page": {"dtype": "string", "id": null, "_type": "Value"}}, "post_processed": null, "supervised_keys": null, "task_templates": null, "builder_name": "wikitext", "config_name": "wikitext-2-raw-v1", "version": {"version_str": "1.0.0", "description": null, "major": 1, "minor": 0, "patch": 0}, "splits": {"test": {"name": "test", "num_bytes": 1290775, "num_examples": 62, "dataset_name": "wikitext"}, "train": {"name": "train", "num_bytes": 10942633, "num_examples": 629, "dataset_name": "wikitext"}, "validation": {"name": "validation", "num_bytes": 1147025, "num_examples": 60, "dataset_name": "wikitext"}}, "download_checksums": {"https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip": {"num_bytes": 4721645, "checksum": "ef7edb566e3e2b2d31b29c1fdb0c89a4cc683597484c3dc2517919c615435a11"}}, "download_size": 4721645, "post_processing_size": null, "dataset_size": 13380433, "size_in_bytes": 18102078}} {"wikitext-103-v1": {"description": " The WikiText language modeling dataset is a collection of over 100 million tokens extracted from the set of verified\n Good and Featured articles on Wikipedia. 
The dataset is available under the Creative Commons Attribution-ShareAlike\n License.\n", "citation": "@misc{merity2016pointer,\n title={Pointer Sentinel Mixture Models},\n author={Stephen Merity and Caiming Xiong and James Bradbury and Richard Socher},\n year={2016},\n eprint={1609.07843},\n archivePrefix={arXiv},\n primaryClass={cs.CL}\n}\n", "homepage": "https://blog.einstein.ai/the-wikitext-long-term-dependency-language-modeling-dataset/", "license": "Creative Commons Attribution-ShareAlike 4.0 International (CC BY-SA 4.0)", "features": {"page": {"dtype": "string", "id": null, "_type": "Value"}}, "post_processed": null, "supervised_keys": null, "task_templates": null, "builder_name": "wikitext", "config_name": "wikitext-103-v1", "version": {"version_str": "1.0.0", "description": null, "major": 1, "minor": 0, "patch": 0}, "splits": {"test": {"name": "test", "num_bytes": 1281262, "num_examples": 62, "dataset_name": "wikitext"}, "train": {"name": "train", "num_bytes": 539297488, "num_examples": 29444, "dataset_name": "wikitext"}, "validation": {"name": "validation", "num_bytes": 1142488, "num_examples": 60, "dataset_name": "wikitext"}}, "download_checksums": {"https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip": {"num_bytes": 190229076, "checksum": "242ba0f20b329cfdf1ccc61e9e9e5b59becf189db7f7a81cd2a0e2fc31539590"}}, "download_size": 190229076, "post_processing_size": null, "dataset_size": 541721238, "size_in_bytes": 731950314}, "wikitext-2-v1": {"description": " The WikiText language modeling dataset is a collection of over 100 million tokens extracted from the set of verified\n Good and Featured articles on Wikipedia. The dataset is available under the Creative Commons Attribution-ShareAlike\n License.\n", "citation": "@misc{merity2016pointer,\n title={Pointer Sentinel Mixture Models},\n author={Stephen Merity and Caiming Xiong and James Bradbury and Richard Socher},\n year={2016},\n eprint={1609.07843},\n archivePrefix={arXiv},\n primaryClass={cs.CL}\n}\n", "homepage": "https://blog.einstein.ai/the-wikitext-long-term-dependency-language-modeling-dataset/", "license": "Creative Commons Attribution-ShareAlike 4.0 International (CC BY-SA 4.0)", "features": {"page": {"dtype": "string", "id": null, "_type": "Value"}}, "post_processed": null, "supervised_keys": null, "task_templates": null, "builder_name": "wikitext", "config_name": "wikitext-2-v1", "version": {"version_str": "1.0.0", "description": null, "major": 1, "minor": 0, "patch": 0}, "splits": {"test": {"name": "test", "num_bytes": 1256634, "num_examples": 62, "dataset_name": "wikitext"}, "train": {"name": "train", "num_bytes": 10799034, "num_examples": 629, "dataset_name": "wikitext"}, "validation": {"name": "validation", "num_bytes": 1121860, "num_examples": 60, "dataset_name": "wikitext"}}, "download_checksums": {"https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip": {"num_bytes": 4475746, "checksum": "92675f1d63015c1c8b51f1656a52d5bdbc33aafa60cc47a218a66e7ee817488c"}}, "download_size": 4475746, "post_processing_size": null, "dataset_size": 13177528, "size_in_bytes": 17653274}, "wikitext-103-raw-v1": {"description": " The WikiText language modeling dataset is a collection of over 100 million tokens extracted from the set of verified\n Good and Featured articles on Wikipedia. 
The dataset is available under the Creative Commons Attribution-ShareAlike\n License.\n", "citation": "@misc{merity2016pointer,\n title={Pointer Sentinel Mixture Models},\n author={Stephen Merity and Caiming Xiong and James Bradbury and Richard Socher},\n year={2016},\n eprint={1609.07843},\n archivePrefix={arXiv},\n primaryClass={cs.CL}\n}\n", "homepage": "https://blog.einstein.ai/the-wikitext-long-term-dependency-language-modeling-dataset/", "license": "Creative Commons Attribution-ShareAlike 4.0 International (CC BY-SA 4.0)", "features": {"page": {"dtype": "string", "id": null, "_type": "Value"}}, "post_processed": null, "supervised_keys": null, "task_templates": null, "builder_name": "wikitext", "config_name": "wikitext-103-raw-v1", "version": {"version_str": "1.0.0", "description": null, "major": 1, "minor": 0, "patch": 0}, "splits": {"test": {"name": "test", "num_bytes": 1290775, "num_examples": 62, "dataset_name": "wikitext"}, "train": {"name": "train", "num_bytes": 540656522, "num_examples": 29444, "dataset_name": "wikitext"}, "validation": {"name": "validation", "num_bytes": 1147025, "num_examples": 60, "dataset_name": "wikitext"}}, "download_checksums": {"https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-raw-v1.zip": {"num_bytes": 191984949, "checksum": "91c00ae287f0d699e18605c84afc9e45c192bc6b7797ff8837e5474655a33794"}}, "download_size": 191984949, "post_processing_size": null, "dataset_size": 543094322, "size_in_bytes": 735079271}, "wikitext-2-raw-v1": {"description": " The WikiText language modeling dataset is a collection of over 100 million tokens extracted from the set of verified\n Good and Featured articles on Wikipedia. The dataset is available under the Creative Commons Attribution-ShareAlike\n License.\n", "citation": "@misc{merity2016pointer,\n title={Pointer Sentinel Mixture Models},\n author={Stephen Merity and Caiming Xiong and James Bradbury and Richard Socher},\n year={2016},\n eprint={1609.07843},\n archivePrefix={arXiv},\n primaryClass={cs.CL}\n}\n", "homepage": "https://blog.einstein.ai/the-wikitext-long-term-dependency-language-modeling-dataset/", "license": "Creative Commons Attribution-ShareAlike 4.0 International (CC BY-SA 4.0)", "features": {"page": {"dtype": "string", "id": null, "_type": "Value"}}, "post_processed": null, "supervised_keys": null, "task_templates": null, "builder_name": "wikitext", "config_name": "wikitext-2-raw-v1", "version": {"version_str": "1.0.0", "description": null, "major": 1, "minor": 0, "patch": 0}, "splits": {"test": {"name": "test", "num_bytes": 1290775, "num_examples": 62, "dataset_name": "wikitext"}, "train": {"name": "train", "num_bytes": 10942633, "num_examples": 629, "dataset_name": "wikitext"}, "validation": {"name": "validation", "num_bytes": 1147025, "num_examples": 60, "dataset_name": "wikitext"}}, "download_checksums": {"https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip": {"num_bytes": 4721645, "checksum": "ef7edb566e3e2b2d31b29c1fdb0c89a4cc683597484c3dc2517919c615435a11"}}, "download_size": 4721645, "post_processing_size": null, "dataset_size": 13380433, "size_in_bytes": 18102078}}
\ No newline at end of file
...@@ -123,86 +123,111 @@ class Wikitext(datasets.GeneratorBasedBuilder): ...@@ -123,86 +123,111 @@ class Wikitext(datasets.GeneratorBasedBuilder):
return [ return [
datasets.SplitGenerator( datasets.SplitGenerator(
name=datasets.Split.TEST, name=datasets.Split.TEST,
gen_kwargs={"data_file": os.path.join( gen_kwargs={
data_dir, "wiki.test.tokens"), "split": "test"}, "data_file": os.path.join(data_dir, "wiki.test.tokens"),
"split": "test",
},
), ),
datasets.SplitGenerator( datasets.SplitGenerator(
name=datasets.Split.TRAIN, name=datasets.Split.TRAIN,
gen_kwargs={"data_file": os.path.join( gen_kwargs={
data_dir, "wiki.train.tokens"), "split": "train"}, "data_file": os.path.join(data_dir, "wiki.train.tokens"),
"split": "train",
},
), ),
datasets.SplitGenerator( datasets.SplitGenerator(
name=datasets.Split.VALIDATION, name=datasets.Split.VALIDATION,
gen_kwargs={"data_file": os.path.join( gen_kwargs={
data_dir, "wiki.valid.tokens"), "split": "valid"}, "data_file": os.path.join(data_dir, "wiki.valid.tokens"),
"split": "valid",
},
), ),
] ]
else: else:
if self.config.name == "wikitext-103-raw-v1": if self.config.name == "wikitext-103-raw-v1":
data_file = dl_manager.download_and_extract( data_file = dl_manager.download_and_extract(self.config.data_url)
self.config.data_url)
data_dir = os.path.join(data_file, "wikitext-103-raw") data_dir = os.path.join(data_file, "wikitext-103-raw")
return [ return [
datasets.SplitGenerator( datasets.SplitGenerator(
name=datasets.Split.TEST, name=datasets.Split.TEST,
gen_kwargs={"data_file": os.path.join( gen_kwargs={
data_dir, "wiki.test.raw"), "split": "test"}, "data_file": os.path.join(data_dir, "wiki.test.raw"),
"split": "test",
},
), ),
datasets.SplitGenerator( datasets.SplitGenerator(
name=datasets.Split.TRAIN, name=datasets.Split.TRAIN,
gen_kwargs={"data_file": os.path.join( gen_kwargs={
data_dir, "wiki.train.raw"), "split": "train"}, "data_file": os.path.join(data_dir, "wiki.train.raw"),
"split": "train",
},
), ),
datasets.SplitGenerator( datasets.SplitGenerator(
name=datasets.Split.VALIDATION, name=datasets.Split.VALIDATION,
gen_kwargs={"data_file": os.path.join( gen_kwargs={
data_dir, "wiki.valid.raw"), "split": "valid"}, "data_file": os.path.join(data_dir, "wiki.valid.raw"),
"split": "valid",
},
), ),
] ]
else: else:
if self.config.name == "wikitext-2-raw-v1": if self.config.name == "wikitext-2-raw-v1":
data_file = dl_manager.download_and_extract( data_file = dl_manager.download_and_extract(self.config.data_url)
self.config.data_url)
data_dir = os.path.join(data_file, "wikitext-2-raw") data_dir = os.path.join(data_file, "wikitext-2-raw")
return [ return [
datasets.SplitGenerator( datasets.SplitGenerator(
name=datasets.Split.TEST, name=datasets.Split.TEST,
gen_kwargs={"data_file": os.path.join( gen_kwargs={
data_dir, "wiki.test.raw"), "split": "test"}, "data_file": os.path.join(data_dir, "wiki.test.raw"),
"split": "test",
},
), ),
datasets.SplitGenerator( datasets.SplitGenerator(
name=datasets.Split.TRAIN, name=datasets.Split.TRAIN,
gen_kwargs={"data_file": os.path.join( gen_kwargs={
data_dir, "wiki.train.raw"), "split": "train"}, "data_file": os.path.join(data_dir, "wiki.train.raw"),
"split": "train",
},
), ),
datasets.SplitGenerator( datasets.SplitGenerator(
name=datasets.Split.VALIDATION, name=datasets.Split.VALIDATION,
gen_kwargs={"data_file": os.path.join( gen_kwargs={
data_dir, "wiki.valid.raw"), "split": "valid"}, "data_file": os.path.join(data_dir, "wiki.valid.raw"),
"split": "valid",
},
), ),
] ]
else: else:
if self.config.name == "wikitext-2-v1": if self.config.name == "wikitext-2-v1":
data_file = dl_manager.download_and_extract( data_file = dl_manager.download_and_extract(
self.config.data_url) self.config.data_url
)
data_dir = os.path.join(data_file, "wikitext-2") data_dir = os.path.join(data_file, "wikitext-2")
return [ return [
datasets.SplitGenerator( datasets.SplitGenerator(
name=datasets.Split.TEST, name=datasets.Split.TEST,
gen_kwargs={"data_file": os.path.join( gen_kwargs={
data_dir, "wiki.test.tokens"), "split": "test"}, "data_file": os.path.join(
data_dir, "wiki.test.tokens"
),
"split": "test",
},
), ),
datasets.SplitGenerator( datasets.SplitGenerator(
name=datasets.Split.TRAIN, name=datasets.Split.TRAIN,
gen_kwargs={ gen_kwargs={
"data_file": os.path.join(data_dir, "wiki.train.tokens"), "data_file": os.path.join(
data_dir, "wiki.train.tokens"
),
"split": "train", "split": "train",
}, },
), ),
datasets.SplitGenerator( datasets.SplitGenerator(
name=datasets.Split.VALIDATION, name=datasets.Split.VALIDATION,
gen_kwargs={ gen_kwargs={
"data_file": os.path.join(data_dir, "wiki.valid.tokens"), "data_file": os.path.join(
data_dir, "wiki.valid.tokens"
),
"split": "valid", "split": "valid",
}, },
), ),
...@@ -216,12 +241,12 @@ class Wikitext(datasets.GeneratorBasedBuilder): ...@@ -216,12 +241,12 @@ class Wikitext(datasets.GeneratorBasedBuilder):
data = f.read().split("\n") data = f.read().split("\n")
for line in data: for line in data:
rline = line.replace("= = =", "===").replace("= =", "==").strip() rline = line.replace("= = =", "===").replace("= =", "==").strip()
if rline.startswith('= ') and rline.strip().endswith(' ='): if rline.startswith("= ") and rline.strip().endswith(" ="):
page = '\n'.join(ret) page = "\n".join(ret)
if page.strip(): if page.strip():
yield key, {"page": page} yield key, {"page": page}
key += 1 key += 1
ret = [] ret = []
ret.append(line) ret.append(line)
page = '\n'.join(ret) page = "\n".join(ret)
yield key, {"page": page} yield key, {"page": page}
...@@ -4,13 +4,18 @@ import json ...@@ -4,13 +4,18 @@ import json
import jsonlines import jsonlines
import io import io
import datetime import datetime
import mmap
import tqdm
from pathlib import Path
def json_serial(obj): def json_serial(obj):
"""JSON serializer for objects not serializable by default json code""" """JSON serializer for objects not serializable by default json code"""
if isinstance(obj, (datetime.datetime,)): if isinstance(obj, (datetime.datetime,)):
return obj.isoformat() return obj.isoformat()
raise TypeError ("Type %s not serializable" % type(obj)) raise TypeError("Type %s not serializable" % type(obj))
# Modified version of lm_dataformat Archive for single file. # Modified version of lm_dataformat Archive for single file.
class Archive: class Archive:
...@@ -18,26 +23,32 @@ class Archive: ...@@ -18,26 +23,32 @@ class Archive:
self.file_path = file_path self.file_path = file_path
dir_name = os.path.dirname(file_path) dir_name = os.path.dirname(file_path)
if dir_name: if dir_name:
os.makedirs(dir_name, exist_ok=True) os.makedirs(dir_name, exist_ok=True)
self.fh = open(self.file_path, 'wb') self.fh = open(self.file_path, "wb")
self.cctx = zstandard.ZstdCompressor(level=compression_level) self.cctx = zstandard.ZstdCompressor(level=compression_level)
self.compressor = self.cctx.stream_writer(self.fh) self.compressor = self.cctx.stream_writer(self.fh)
def add_data(self, data, meta={}): def add_data(self, data, meta={}):
self.compressor.write(json.dumps({'text': data, 'meta': meta}, default=json_serial).encode('UTF-8') + b'\n') self.compressor.write(
json.dumps({"text": data, "meta": meta}, default=json_serial).encode(
"UTF-8"
)
+ b"\n"
)
def commit(self): def commit(self):
self.compressor.flush(zstandard.FLUSH_FRAME) self.compressor.flush(zstandard.FLUSH_FRAME)
self.fh.flush() self.fh.flush()
self.fh.close() self.fh.close()
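# Illustrative usage sketch (not part of this commit): writing a compressed
# .jsonl.zst archive with the Archive class above. The path and document text
# are hypothetical, and the constructor is assumed to default its compression_level.
archive = Archive("output/example.jsonl.zst")
archive.add_data("First document text.", meta={"source": "example"})
archive.add_data("Second document text.")
archive.commit()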
# Modified version of lm_dataformat Reader with self.fh set, allowing peeking for tqdm. # Modified version of lm_dataformat Reader with self.fh set, allowing peeking for tqdm.
class Reader: class Reader:
def __init__(self): def __init__(self):
pass pass
def read(self, file, get_meta=False, autojoin_paragraphs=True, para_joiner='\n\n'): def read(self, file, get_meta=False, autojoin_paragraphs=True, para_joiner="\n\n"):
with open(file, 'rb') as fh: with open(file, "rb") as fh:
self.fh = fh self.fh = fh
cctx = zstandard.ZstdDecompressor() cctx = zstandard.ZstdDecompressor()
reader = io.BufferedReader(cctx.stream_reader(fh)) reader = io.BufferedReader(cctx.stream_reader(fh))
...@@ -49,42 +60,102 @@ class Reader: ...@@ -49,42 +60,102 @@ class Reader:
yield ob yield ob
continue continue
text = ob['text'] text = ob["text"]
if autojoin_paragraphs and isinstance(text, list): if autojoin_paragraphs and isinstance(text, list):
text = para_joiner.join(text) text = para_joiner.join(text)
if get_meta: if get_meta:
yield text, (ob['meta'] if 'meta' in ob else {}) yield text, (ob["meta"] if "meta" in ob else {})
else: else:
yield text yield text
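# Illustrative sketch (not part of this commit): reading documents back out of a
# .jsonl.zst file written by the Archive above. The path is hypothetical.
reader = Reader()
for text, meta in reader.read("output/example.jsonl.zst", get_meta=True):
    print(len(text), meta)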
# Simple text reader and writer with same interface as above
class TextArchive: class TextArchive:
def __init__(self, file_path, mode="ab"): def __init__(self, file_path, mode="rb+"):
self.file_path = file_path self.file_path = file_path
dir_name = os.path.dirname(file_path) dir_name = os.path.dirname(file_path)
if dir_name: if dir_name:
os.makedirs(dir_name, exist_ok=True) os.makedirs(dir_name, exist_ok=True)
self.fh = open(self.file_path, mode)
if not os.path.exists(file_path):
def add_data(self, data, meta={}): Path(file_path).touch()
self.fh.write(data.encode('UTF-8') + b'\n')
self.fh = open(self.file_path, mode)
def add_data(self, data):
self.fh.write(data.encode("UTF-8") + b"\n")
def commit(self): def commit(self):
self.fh.flush() self.fh.flush()
self.fh.close() self.fh.close()
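# Illustrative sketch (not part of this commit): writing lines to a plain-text
# archive with TextArchive. The path is hypothetical.
text_archive = TextArchive("output/lines.txt")
text_archive.add_data("one line of text per add_data call")
text_archive.commit()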
class TextReader: class TextReader:
def __init__(self, file_path): def __init__(self, file_path):
self.file_path = file_path self.file_path = file_path
# Optimized mmap read with infrequent tqdm updates to maintain speed
# Tested up to 250MB/s.
def read_tqdm(self, update_frequency=10000):
current_file_position = 0
line_counter = 0
with open(self.file_path, "r") as fh, tqdm.tqdm(
total=os.path.getsize(self.file_path),
dynamic_ncols=True,
unit="byte",
unit_scale=1,
) as progress:
with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
for line in iter(mmap_obj.readline, b""):
line = line.decode("utf-8")
line_counter += 1
if line_counter == update_frequency:
new_file_pos = mmap_obj.tell()
bytes_read = new_file_pos - current_file_position
current_file_position = new_file_pos
progress.update(bytes_read)
line_counter = 0
yield line[:-1]
def read_and_tell(self):
current_file_position = 0
with open(self.file_path, "r", encoding="utf8") as fh:
with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
for line in iter(mmap_obj.readline, b""):
line = line.decode("utf-8")
new_file_pos = mmap_obj.tell()
raw_bytes_read = new_file_pos - current_file_position
current_file_position = new_file_pos
yield line[:-1], raw_bytes_read
def read(self): def read(self):
with open(self.file_path, 'r', encoding="utf8") as fh: with open(self.file_path, "r", encoding="utf8") as fh:
self.fh = fh with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
for line in iter(mmap_obj.readline, b""):
line = line.decode("utf-8")
yield line[:-1]
def read_slow(self):
with open(self.file_path, "r", encoding="utf8") as fh:
while True: while True:
line = self.fh.readline() line = fh.readline()
if line == -1 or line == "": if line == -1 or line == "":
break break
else : else:
yield line[:-1] yield line[:-1]
\ No newline at end of file
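# Illustrative sketch (not part of this commit): streaming lines back out of a
# large text file with the mmap-backed TextReader and a tqdm progress bar. The
# path is hypothetical.
line_reader = TextReader("output/lines.txt")
for line in line_reader.read_tqdm():
    pass  # process each line (trailing newline already stripped)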
# Optimized for speed. Decompresses the archive in shell before
# using the mmap'd TextReader.
class ZStdTextReader:
def __init__(self, file):
self.file = file
def read_tqdm(self):
decompressed_file = self.file[:-4]
print("Decompressing file, please wait...")
os.system(f"zstd -d {self.file}") # linux decompress is faster
reader = TextReader(decompressed_file)
yield from reader.read_tqdm()
os.remove(decompressed_file)
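# Illustrative sketch (not part of this commit): ZStdTextReader shells out to the
# `zstd` command-line tool, so it assumes zstd is installed and on PATH. The file
# name is hypothetical; read_tqdm removes the decompressed copy when finished.
zst_reader = ZStdTextReader("ngrams_0.bkt.txt.sorted.zst")
for line in zst_reader.read_tqdm():
    pass  # each line is e.g. "<ngram> <document_id>"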
import time
import random
import pickle
import json
import glob
import os
import collections
from .janitor import Janitor, word_ngrams
from .archiver import ZStdTextReader
# Was used for testing the evaluator decoupled from the full logic below
def get_train_overlap_stub(docs, ngrams_path, ngrams_n_size):
simulated_overlap = 0.1
contaminated = int(len(docs) * simulated_overlap)
return random.sample(range(len(docs)), contaminated)
# Returns a dictionary containing all overlapping documents in each
# task. In the standard use case, an overlap occurs when any of the 13-grams
# found in the task document exist in the training set documents.
#
# To generate 13-grams for the Pile, see scripts/clean_training_data. The final output of these
# scripts is an info.json file containing the ngram_size (13) and a set of "ngrams_{x}.bkt.txt.sorted.zst"
# files. These should exist in the "ngrams_path" provided to this function.
# Algorithm:
# 1. Build lookups for each dataset {ngram: list(document_ids)}
# 2. Merge into an overall lookup {ngram: [(task_name, task_set, doc_ids),]}
# 3. Full scan the 13-grams from the training set against the merged lookup,
# saving matches in the "duplicates" dictionary {(task_name, task_set): set(doc_ids)}
# 4. Strip the task_set from the dictionary keys and return
#
# We cache the task+set lookups as well as the overlaps.
def get_train_overlap(docs_by_task_set, ngrams_path, limit):
# return get_train_overlap_stub(docs, ngrams_path, ngrams_n_size)
info_dict_path = os.path.join(ngrams_path, "info.json")
info_dict = json.load(open(info_dict_path, "r"))
ngrams_n_size = info_dict["ngram_size"]
janitor = Janitor()
# Build lookup for each dataset first in case we use different task combinations later
print("Building Lookups...")
start = time.perf_counter()
def get_overlaps_dump_path(task_name, task_set, ngrams_n_size, limit):
return f"data/{task_name}/{task_set}_{ngrams_n_size}grams_limit{limit}.overlaps"
lookups = {}
duplicates = {}  # {(task_name, task_set): set(doc_ids)}
sets_to_decontaminate = len(docs_by_task_set.keys())
for (task_name, task_set), docs in docs_by_task_set.items():
if not os.path.exists(f"data/{task_name}"):
os.mkdir(f"data/{task_name}")
# Check if we've decontaminated this combination before
overlaps_dump_path = get_overlaps_dump_path(
task_name, task_set, ngrams_n_size, limit
)
if os.path.exists(overlaps_dump_path):
duplicates[(task_name, task_set)] = pickle.load(
open(overlaps_dump_path, "rb")
)
sets_to_decontaminate -= 1
continue
else:
duplicates[(task_name, task_set)] = set()
# Build/load the task lookup {ngram: set(documents)}.
task_set_lookup_path = (
f"data/{task_name}/{task_set}_{ngrams_n_size}grams_limit{limit}.lookup"
)
if os.path.exists(task_set_lookup_path):
print(f"{task_set_lookup_path} available, loading...")
lookups[(task_name, task_set)] = pickle.load(
open(task_set_lookup_path, "rb")
)
else:
print(f"{task_set_lookup_path} not available, building...")
lookup = collections.defaultdict(set)
for doc_id, document in enumerate(docs):
ngrams = word_ngrams(janitor.normalize_string(document), ngrams_n_size)
for ngram in ngrams:
lookup[ngram].add(doc_id)
pickle.dump(lookup, open(task_set_lookup_path, "wb"))
lookups[(task_name, task_set)] = lookup
elapsed = time.perf_counter() - start
print(f"Building lookups took {elapsed:0.5f} seconds.")
matched_ngrams = []
if sets_to_decontaminate > 0:
print("Merging lookups...")
start = time.perf_counter()
merged_lookup = collections.defaultdict(list)
for (task_name, task_set), lookup in lookups.items():
for ngram, doc_ids in lookup.items():
merged_lookup[ngram].append((task_name, task_set, doc_ids))
elapsed = time.perf_counter() - start
print(f"Merging lookups took {elapsed:0.5f} seconds.")
print(f"{ngrams_n_size} grams files found in {ngrams_path}:")
files = glob.glob(os.path.join(ngrams_path, f"*.sorted.zst"))
print(files)
for file in files:
start = time.perf_counter()
print(f"Scanning {file}")
reader = ZStdTextReader(file)
total_ngrams = 0
unique_ngrams = 0
matching_unique = 0
non_matching_unique = 0
current_ngram = ""
for line in reader.read_tqdm(): # Scan training set ngrams file
total_ngrams += 1
[ngram, document_id] = line.rsplit(" ", 1)
if (
ngram != current_ngram
): # Only need to match the ngram once in training set
unique_ngrams += 1
current_ngram = ngram
if ngram in merged_lookup:
matched_ngrams.append(ngram) # For logging
matching_unique += 1
for task_name, task_set, doc_ids in merged_lookup[ngram]:
task_doc_set = duplicates[(task_name, task_set)]
for (
doc_id
) in (
doc_ids
): # Record contamination across all relevant task/set combos
task_doc_set.add(doc_id)
del merged_lookup[ngram] # No point matching again
else:
non_matching_unique += 1
print(f"Total Ngrams: {total_ngrams}")
print(f"Unique Ngrams: {unique_ngrams}")
print(f"Unique Matching: {matching_unique}")
print(f"Unique Non Matching: {non_matching_unique}")
print("Matched ngrams:")
for ngram in matched_ngrams:
print(ngram)
elapsed = time.perf_counter() - start
print(f"Read took {elapsed:0.5f} seconds.")
print(f"Speed: {(os.path.getsize(file)/1000000.0)/elapsed}MB/second")
print(duplicates)
# Dump overlaps separately
for (task_name, task_set), doc_ids in duplicates.items():
overlaps_dump_path = get_overlaps_dump_path(
task_name, task_set, ngrams_n_size, limit
)
pickle.dump(doc_ids, open(overlaps_dump_path, "wb"))
# Strip task set and return
return {task_name: doc_ids for (task_name, task_set), doc_ids in duplicates.items()}
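# Illustrative sketch (not part of this commit) of the per-task lookup-building
# step described above: normalize each document with the Janitor, take its
# 13-grams, and index them as {ngram: set(doc_ids)}. The per-task lookups are
# then merged so the training-set ngram files only have to be scanned once.
def build_lookup_sketch(docs, ngram_n=13):
    janitor = Janitor()
    lookup = collections.defaultdict(set)
    for doc_id, document in enumerate(docs):
        for ngram in word_ngrams(janitor.normalize_string(document), ngram_n):
            lookup[ngram].add(doc_id)
    return lookup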
...@@ -9,8 +9,9 @@ from pprint import pprint ...@@ -9,8 +9,9 @@ from pprint import pprint
# c++ -O3 -Wall -shared -std=c++11 -fPIC $(python3 -m pybind11 --includes) janitor_util.cpp -o janitor_util$(python3-config --extension-suffix) -undefined dynamic_lookup # c++ -O3 -Wall -shared -std=c++11 -fPIC $(python3 -m pybind11 --includes) janitor_util.cpp -o janitor_util$(python3-config --extension-suffix) -undefined dynamic_lookup
try: try:
import janitor_util import janitor_util
JANITOR_CPP = True JANITOR_CPP = True
except Exception as e: except Exception:
print("WARNING: C++ module could not be loaded. Janitor running in python mode") print("WARNING: C++ module could not be loaded. Janitor running in python mode")
traceback.print_exc() traceback.print_exc()
JANITOR_CPP = False JANITOR_CPP = False
...@@ -41,6 +42,7 @@ def word_ngrams(s, n): ...@@ -41,6 +42,7 @@ def word_ngrams(s, n):
ngram_seqs = form_ngrams(iter(tokens), n) ngram_seqs = form_ngrams(iter(tokens), n)
return (" ".join(ngram) for ngram in ngram_seqs) return (" ".join(ngram) for ngram in ngram_seqs)
# Does character sequences only - combined faster function to play around with later # Does character sequences only - combined faster function to play around with later
# def word_ngrams_indices_combined(sequence, n): # def word_ngrams_indices_combined(sequence, n):
# current_word = "" # current_word = ""
...@@ -70,7 +72,7 @@ def split_indices(s): ...@@ -70,7 +72,7 @@ def split_indices(s):
"""Splits a string on whitespaces and records the indices of each in the original string. """Splits a string on whitespaces and records the indices of each in the original string.
@:return generator((word, (start_idx, end_idx)), ...) @:return generator((word, (start_idx, end_idx)), ...)
""" """
return ((m.group(0), (m.start(), m.end() - 1)) for m in re.finditer(r'\S+', s)) return ((m.group(0), (m.start(), m.end() - 1)) for m in re.finditer(r"\S+", s))
def word_ngrams_indices(s, n): def word_ngrams_indices(s, n):
...@@ -90,22 +92,27 @@ def word_ngrams_indices(s, n): ...@@ -90,22 +92,27 @@ def word_ngrams_indices(s, n):
# ([word, word, ...], [(start,end), (start,end), ...]), # ([word, word, ...], [(start,end), (start,end), ...]),
# ... # ...
# ) # )
ngram_indices_pairs = (zip(*ngram_with_indices) for ngram_with_indices in ngram_seqs_with_indices) ngram_indices_pairs = (
zip(*ngram_with_indices) for ngram_with_indices in ngram_seqs_with_indices
)
# Generator of ( (word_ngram, (start, end)), (word_ngram, start, end)), ...) # Generator of ( (word_ngram, (start, end)), (word_ngram, start, end)), ...)
return ((" ".join(ngram_seq), (indices[0][0], indices[-1][1])) for ngram_seq, indices in ngram_indices_pairs) return (
(" ".join(ngram_seq), (indices[0][0], indices[-1][1]))
for ngram_seq, indices in ngram_indices_pairs
)
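# For example (illustrative, not part of this commit), with inclusive end indices:
#   list(word_ngrams_indices("hello brown fox", 2))
#   -> [("hello brown", (0, 10)), ("brown fox", (6, 14))]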
class Janitor: class Janitor:
# FIXME delete_chars: Should anything else go here? Special chars? # FIXME delete_chars: Should anything else go here? Special chars?
def __init__( def __init__(
self, self,
ngram_n=13, ngram_n=13,
window_to_remove=200, window_to_remove=200,
too_dirty_cutoff=10, too_dirty_cutoff=10,
minimum_slice_length=200, minimum_slice_length=200,
delete_chars=string.punctuation delete_chars=string.punctuation,
): ):
self.ngram_n = ngram_n self.ngram_n = ngram_n
self.window_to_remove = window_to_remove self.window_to_remove = window_to_remove
...@@ -121,7 +128,7 @@ class Janitor: ...@@ -121,7 +128,7 @@ class Janitor:
self.translation_table = str.maketrans( self.translation_table = str.maketrans(
string.ascii_lowercase + string.ascii_uppercase, # These characters string.ascii_lowercase + string.ascii_uppercase, # These characters
string.ascii_lowercase * 2, # Become these characters string.ascii_lowercase * 2, # Become these characters
self.delete_chars # These are deleted self.delete_chars, # These are deleted
) )
############## ##############
...@@ -129,14 +136,13 @@ class Janitor: ...@@ -129,14 +136,13 @@ class Janitor:
############## ##############
def save_contamination_ngrams(self, filename): def save_contamination_ngrams(self, filename):
with open(filename, 'wb') as fp: with open(filename, "wb") as fp:
pickle.dump(filename, fp) pickle.dump(filename, fp)
def load_contamination_ngrams(self, filename): def load_contamination_ngrams(self, filename):
with open(filename, 'rb') as fp: with open(filename, "rb") as fp:
self.dirt_ngrams = pickle.load(fp) self.dirt_ngrams = pickle.load(fp)
############## ##############
# Call these :) # Call these :)
############## ##############
...@@ -152,7 +158,7 @@ class Janitor: ...@@ -152,7 +158,7 @@ class Janitor:
def clean(self, dirty_string): def clean(self, dirty_string):
"""Clean a string (e.g. a training set) by removing all ngrams previously """Clean a string (e.g. a training set) by removing all ngrams previously
reigstered as contaminants. Returns a list of clean chunks, or empty if registered as contaminants. Returns a list of clean chunks, or empty if
the string was too dirty""" the string was too dirty"""
if JANITOR_CPP: if JANITOR_CPP:
return self.clean_cpp(dirty_string) return self.clean_cpp(dirty_string)
...@@ -171,11 +177,11 @@ class Janitor: ...@@ -171,11 +177,11 @@ class Janitor:
end = min(len(dirty_string), end + self.window_to_remove) end = min(len(dirty_string), end + self.window_to_remove)
if start - splice_idx > self.minimum_slice_length: if start - splice_idx > self.minimum_slice_length:
clean_chunks.append(dirty_string[splice_idx: start]) clean_chunks.append(dirty_string[splice_idx:start])
splice_idx = end splice_idx = end
if end < len(dirty_string) - self.minimum_slice_length: if end < len(dirty_string) - self.minimum_slice_length:
clean_chunks.append(dirty_string[end+1:]) clean_chunks.append(dirty_string[end + 1 :])
return clean_chunks return clean_chunks
...@@ -184,10 +190,14 @@ class Janitor: ...@@ -184,10 +190,14 @@ class Janitor:
############## ##############
def register_contaminant_cpp(self, dirt_string): def register_contaminant_cpp(self, dirt_string):
self.dirt_ngrams.update(janitor_util.clean_ngram(dirt_string, self.delete_chars, self.ngram_n)) self.dirt_ngrams.update(
janitor_util.clean_ngram(dirt_string, self.delete_chars, self.ngram_n)
)
def clean_cpp(self, dirty_string): def clean_cpp(self, dirty_string):
contamination_indices = janitor_util.clean_ngram_with_indices(dirty_string, self.delete_chars, self.ngram_n) contamination_indices = janitor_util.clean_ngram_with_indices(
dirty_string, self.delete_chars, self.ngram_n
)
return self._split_chunks(dirty_string, contamination_indices) return self._split_chunks(dirty_string, contamination_indices)
############## ##############
...@@ -198,7 +208,9 @@ class Janitor: ...@@ -198,7 +208,9 @@ class Janitor:
return s.translate(self.translation_table) return s.translate(self.translation_table)
def register_contaminant_python(self, dirt_string): def register_contaminant_python(self, dirt_string):
self.dirt_ngrams.update(word_ngrams(self.normalize_string(dirt_string), self.ngram_n)) self.dirt_ngrams.update(
word_ngrams(self.normalize_string(dirt_string), self.ngram_n)
)
def clean_python(self, dirty_string): def clean_python(self, dirty_string):
contamination_indices = ( contamination_indices = (
...@@ -249,29 +261,29 @@ class Janitor: ...@@ -249,29 +261,29 @@ class Janitor:
# data = f.read() # data = f.read()
# jan = Janitor(too_dirty_cutoff=1000) # jan = Janitor(too_dirty_cutoff=1000)
# jan.register_contaminant(''' # jan.register_contaminant('''
# theories is that there is a connection between &quot;geekdom&quot; and autism. # theories is that there is a connection between &quot;geekdom&quot; and autism.
# This is hinted, for instance, by a ''Wired Magazine'' article in 2001 entitled &quot; # This is hinted, for instance, by a ''Wired Magazine'' article in 2001 entitled &quot;
# The [[Geek]] Syndrome&quot;, which is a point argued by many in the autism rights # The [[Geek]] Syndrome&quot;, which is a point argued by many in the autism rights
# movement{{ref|Wired}}. This article, many professionals assert, is just one example of # movement{{ref|Wired}}. This article, many professionals assert, is just one example of
# the media's application of mental disease labels to what is actually variant normal behavior # the media's application of mental disease labels to what is actually variant normal behavior
# &amp;mdash;they argue that shyness, lack of athletic ability or social skills, and intellectual # &amp;mdash;they argue that shyness, lack of athletic ability or social skills, and intellectual
# interests, even when they seem unusual to others, are not in themselves signs of autism or # interests, even when they seem unusual to others, are not in themselves signs of autism or
# Asperger's syndrome. Others assert that it is actually the medical profession which is applying # Asperger's syndrome. Others assert that it is actually the medical profession which is applying
# mental disease labels to children who in the past would have simply been accepted as a little # mental disease labels to children who in the past would have simply been accepted as a little
# different or even labeled 'gifted'. See [[clinomorphism]] for further discussion of this issue. # different or even labeled 'gifted'. See [[clinomorphism]] for further discussion of this issue.
# Due to the recent publicity surrounding autism and autis # Due to the recent publicity surrounding autism and autis
# ultan Al Nahyan]] granted [[Petroleum]] concessions, and oil was first found in 1958. At first, # ultan Al Nahyan]] granted [[Petroleum]] concessions, and oil was first found in 1958. At first,
# oil money had a marginal impact. A few lowrise concete buildings were erected, and the first # oil money had a marginal impact. A few lowrise concete buildings were erected, and the first
# paved road was completed in 1961, but Sheikh Shakbut, uncertain whether the new oil royalties # paved road was completed in 1961, but Sheikh Shakbut, uncertain whether the new oil royalties
# would last, took a cautious approach, prefering to save the revenue rather than investing it in # would last, took a cautious approach, preferring to save the revenue rather than investing it in
# development. His brother, [[Zayed bin Sultan Al Nahayan]], saw that oil wealth had the potential # development. His brother, [[Zayed bin Sultan Al Nahayan]], saw that oil wealth had the potential
# to transform Abu Dhabi. The ruling Al Nahayan family decided that Sheikh Zayed should replace his # to transform Abu Dhabi. The ruling Al Nahayan family decided that Sheikh Zayed should replace his
# brother as Ruler and carry out his vision of developing the country. On [[August 6]], [[1966]], # brother as Ruler and carry out his vision of developing the country. On [[August 6]], [[1966]],
# with the assistance of the British, Sheikh Zayed became the new ruler. See generally, Al-Fahim, M, # with the assistance of the British, Sheikh Zayed became the new ruler. See generally, Al-Fahim, M,
# ''From Rags to Riches: A Story of Abu Dhabi'', Chapter Six (London Centre of Arab Studies, 1995), # ''From Rags to Riches: A Story of Abu Dhabi'', Chapter Six (London Centre of Arab Studies, 1995),
# ISBN 1 900404 00 1. With the announcement by Britain in 1968 that it would withdraw from the # ISBN 1 900404 00 1. With the announcement by Britain in 1968 that it would withdraw from the
# Gulf area by 1971, Sheikh Zayed became the main driving force behind the formation of the # Gulf area by 1971, Sheikh Zayed became the main driving force behind the formation of the
# [[United Arab Emirates]]. After the Emirates gained independence in 1971, # [[United Arab Emirates]]. After the Emirates gained independence in 1971,
# ''') # ''')
# """ # """
......
import collections import collections
import itertools import itertools
import pathlib import numpy as np
import random import random
import lm_eval.metrics import lm_eval.metrics
import lm_eval.models import lm_eval.models
import lm_eval.tasks import lm_eval.tasks
import lm_eval.base import lm_eval.base
import numpy as np
from lm_eval.utils import positional_deprecated, run_task_tests from lm_eval.utils import positional_deprecated, run_task_tests
@positional_deprecated @positional_deprecated
def simple_evaluate(model, model_args=None, tasks=[], def simple_evaluate(
num_fewshot=0, batch_size=None, device=None, model,
no_cache=False, limit=None, bootstrap_iters=100000, model_args=None,
description_dict=None, check_integrity=False): tasks=[],
num_fewshot=0,
batch_size=None,
device=None,
no_cache=False,
limit=None,
bootstrap_iters=100000,
description_dict=None,
check_integrity=False,
decontamination_ngrams_path=None,
):
"""Instantiate and evaluate a model on a list of tasks. """Instantiate and evaluate a model on a list of tasks.
:param model: Union[str, LM] :param model: Union[str, LM]
Name of model or LM object, see lm_eval.models.get_model Name of model or LM object, see lm_eval.models.get_model
:param model_args: Optional[str] :param model_args: Optional[str]
String arguments for each model class, see LM.create_from_arg_string. String arguments for each model class, see LM.create_from_arg_string.
Ignored if `model` argument is a LM object. Ignored if `model` argument is a LM object.
:param tasks: list[Union[str, Task]] :param tasks: list[Union[str, Task]]
List of task names or Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise. List of task names or Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
...@@ -37,7 +47,7 @@ def simple_evaluate(model, model_args=None, tasks=[], ...@@ -37,7 +47,7 @@ def simple_evaluate(model, model_args=None, tasks=[],
:param bootstrap_iters: :param bootstrap_iters:
Number of iterations for bootstrap statistics Number of iterations for bootstrap statistics
:param description_dict: dict[str, str] :param description_dict: dict[str, str]
Dictionary of custom task descriptions of the form: `task_name: description` Dictionary of custom task descriptions of the form: `task_name: description`
:param check_integrity: bool :param check_integrity: bool
Whether to run the relevant part of the test suite for the tasks Whether to run the relevant part of the test suite for the tasks
:return :return
...@@ -49,19 +59,25 @@ def simple_evaluate(model, model_args=None, tasks=[], ...@@ -49,19 +59,25 @@ def simple_evaluate(model, model_args=None, tasks=[],
assert tasks != [], "No tasks specified" assert tasks != [], "No tasks specified"
if isinstance(model, str): if isinstance(model, str):
if model_args is None: model_args = "" if model_args is None:
lm = lm_eval.models.get_model(model).create_from_arg_string(model_args, { model_args = ""
'batch_size': batch_size, 'device': device lm = lm_eval.models.get_model(model).create_from_arg_string(
}) model_args, {"batch_size": batch_size, "device": device}
)
else: else:
assert isinstance(model, lm_eval.base.LM) assert isinstance(model, lm_eval.base.LM)
lm = model lm = model
if not no_cache: if not no_cache:
lm = lm_eval.base.CachingLM( lm = lm_eval.base.CachingLM(
lm, 'lm_cache/' + model + '_' + model_args.replace('=', '-').replace(',', '_').replace('/', '-') + '.db' lm,
"lm_cache/"
+ model
+ "_"
+ model_args.replace("=", "-").replace(",", "_").replace("/", "-")
+ ".db",
) )
task_dict = lm_eval.tasks.get_task_dict(tasks) task_dict = lm_eval.tasks.get_task_dict(tasks)
if check_integrity: if check_integrity:
...@@ -72,7 +88,9 @@ def simple_evaluate(model, model_args=None, tasks=[], ...@@ -72,7 +88,9 @@ def simple_evaluate(model, model_args=None, tasks=[],
task_dict=task_dict, task_dict=task_dict,
num_fewshot=num_fewshot, num_fewshot=num_fewshot,
limit=limit, limit=limit,
description_dict=description_dict bootstrap_iters=bootstrap_iters,
description_dict=description_dict,
decontamination_ngrams_path=decontamination_ngrams_path,
) )
# add info about the model and few shot config # add info about the model and few shot config
...@@ -85,14 +103,26 @@ def simple_evaluate(model, model_args=None, tasks=[], ...@@ -85,14 +103,26 @@ def simple_evaluate(model, model_args=None, tasks=[],
"no_cache": no_cache, "no_cache": no_cache,
"limit": limit, "limit": limit,
"bootstrap_iters": bootstrap_iters, "bootstrap_iters": bootstrap_iters,
"description_dict": description_dict "description_dict": description_dict,
} }
return results return results
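# Illustrative call sketch (not part of this commit), following the commented-out
# example style used elsewhere in the repo; the model, task, and ngrams directory
# below are hypothetical. Passing decontamination_ngrams_path adds
# "<metric>_decontaminate" entries next to the usual metrics in the results.
# results = simple_evaluate(
#     model="gpt2",
#     tasks=["lambada"],
#     num_fewshot=0,
#     decontamination_ngrams_path="path/to/pile_13grams",
# )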
decontaminate_suffix = "_decontaminate"
@positional_deprecated @positional_deprecated
def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None, bootstrap_iters=100000, description_dict=None): def evaluate(
lm,
task_dict,
provide_description=None,
num_fewshot=0,
limit=None,
bootstrap_iters=100000,
description_dict=None,
decontamination_ngrams_path=None,
):
"""Instantiate and evaluate a model on a list of tasks. """Instantiate and evaluate a model on a list of tasks.
:param lm: obj :param lm: obj
...@@ -108,7 +138,7 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None, ...@@ -108,7 +138,7 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
:param bootstrap_iters: :param bootstrap_iters:
Number of iterations for bootstrap statistics Number of iterations for bootstrap statistics
:param description_dict: dict[str, str] :param description_dict: dict[str, str]
Dictionary of custom task descriptions of the form: `task_name: description` Dictionary of custom task descriptions of the form: `task_name: description`
:return :return
Dictionary of results Dictionary of results
""" """
...@@ -118,12 +148,16 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None, ...@@ -118,12 +148,16 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
assert not provide_description # not implemented. assert not provide_description # not implemented.
if provide_description is not None: if provide_description is not None:
# nudge people to not specify it at all # nudge people to not specify it at all
print("WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict") print(
"WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
)
decontaminate = decontamination_ngrams_path is not None
task_dict_items = [ task_dict_items = [
(name, task) (name, task)
for name, task in task_dict.items() for name, task in task_dict.items()
if(task.has_validation_docs() or task.has_test_docs()) if (task.has_validation_docs() or task.has_test_docs())
] ]
results = collections.defaultdict(dict) results = collections.defaultdict(dict)
...@@ -132,6 +166,8 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None, ...@@ -132,6 +166,8 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
requests = collections.defaultdict(list) requests = collections.defaultdict(list)
requests_origin = collections.defaultdict(list) requests_origin = collections.defaultdict(list)
overlaps = collections.defaultdict(list) # {task_name: contaminated_docs}
# If we ever run into issues where the eval tasks don't fit in memory and we can't afford a machine with bigger # If we ever run into issues where the eval tasks don't fit in memory and we can't afford a machine with bigger
# memory, we can always modify this plumbing to support that, but I didn't want to include it just yet because # memory, we can always modify this plumbing to support that, but I didn't want to include it just yet because
# over-engineering is bad (or we could make it write the requests to disk and then read them back out again # over-engineering is bad (or we could make it write the requests to disk and then read them back out again
...@@ -140,6 +176,8 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None, ...@@ -140,6 +176,8 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
# TODO: we need unit tests & sanity checks or something to ensure that the return of `validation_docs` is stable # TODO: we need unit tests & sanity checks or something to ensure that the return of `validation_docs` is stable
docs = {} docs = {}
docs_for_decontamination = collections.defaultdict(list)
# get lists of each type of request # get lists of each type of request
for task_name, task in task_dict_items: for task_name, task in task_dict_items:
versions[task_name] = task.VERSION versions[task_name] = task.VERSION
...@@ -147,7 +185,9 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None, ...@@ -147,7 +185,9 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
# TODO: the test-fallback-to-val system isn't final, we should revisit it at some point # TODO: the test-fallback-to-val system isn't final, we should revisit it at some point
if task.has_test_docs(): if task.has_test_docs():
task_doc_func = task.test_docs task_doc_func = task.test_docs
task_set = "test" # Required for caching in the decontamination
elif task.has_validation_docs(): elif task.has_validation_docs():
task_set = "val" # Required for caching in the decontamination
task_doc_func = task.validation_docs task_doc_func = task.validation_docs
else: else:
raise RuntimeError("Task has neither test_docs nor validation_docs") raise RuntimeError("Task has neither test_docs nor validation_docs")
...@@ -158,15 +198,22 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None, ...@@ -158,15 +198,22 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
rnd.seed(42) rnd.seed(42)
rnd.shuffle(task_docs) rnd.shuffle(task_docs)
description = description_dict[task_name] if description_dict and task_name in description_dict else "" description = (
description_dict[task_name]
if description_dict and task_name in description_dict
else ""
)
for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)): for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):
if decontaminate and task.should_decontaminate():
docs_for_decontamination[(task_name, task_set)].append(
task.doc_to_decontamination_query(doc)
)
docs[(task_name, doc_id)] = doc docs[(task_name, doc_id)] = doc
ctx = task.fewshot_context( ctx = task.fewshot_context(
doc=doc, doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
num_fewshot=num_fewshot,
rnd=rnd,
description=description
) )
reqs = task.construct_requests(doc, ctx) reqs = task.construct_requests(doc, ctx)
if not isinstance(reqs, (list, tuple)): if not isinstance(reqs, (list, tuple)):
...@@ -177,6 +224,15 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None, ...@@ -177,6 +224,15 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
# doc_id: unique id that we can get back to a doc using `docs` # doc_id: unique id that we can get back to a doc using `docs`
requests_origin[req.request_type].append((i, task_name, doc, doc_id)) requests_origin[req.request_type].append((i, task_name, doc, doc_id))
# Compare all tasks/sets at once to ensure a single training set scan
if decontaminate:
from lm_eval.decontamination.decontaminate import get_train_overlap
print("Finding train/test overlap, please wait...")
overlaps = get_train_overlap(
docs_for_decontamination, decontamination_ngrams_path, limit
)
# all responses for each (task, doc) # all responses for each (task, doc)
process_res_queue = collections.defaultdict(list) process_res_queue = collections.defaultdict(list)
...@@ -189,11 +245,13 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None, ...@@ -189,11 +245,13 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
print("Running", reqtype, "requests") print("Running", reqtype, "requests")
resps = getattr(lm, reqtype)([req.args for req in reqs]) resps = getattr(lm, reqtype)([req.args for req in reqs])
resps = [x if req.index is None else x[req.index] for x, req in zip(resps, reqs)] resps = [
x if req.index is None else x[req.index] for x, req in zip(resps, reqs)
]
for resp, (i, task_name, doc, doc_id) in zip(resps, requests_origin[reqtype]): for resp, (i, task_name, doc, doc_id) in zip(resps, requests_origin[reqtype]):
process_res_queue[(task_name, doc_id)].append((i, resp)) process_res_queue[(task_name, doc_id)].append((i, resp))
vals = collections.defaultdict(list) vals = collections.defaultdict(list)
# unpack results and sort back in order and return control to Task # unpack results and sort back in order and return control to Task
...@@ -207,25 +265,36 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None, ...@@ -207,25 +265,36 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
metrics = task.process_results(doc, requests) metrics = task.process_results(doc, requests)
for metric, value in metrics.items(): for metric, value in metrics.items():
vals[(task_name, metric)].append(value) vals[(task_name, metric)].append(value)
# Re-use the evaluation for the decontaminated set by just ignoring the overlaps
if decontaminate and task_name in overlaps:
if doc_id not in overlaps[task_name]:
vals[(task_name, metric + decontaminate_suffix)].append(value)
# aggregate results # aggregate results
for (task_name, metric), items in vals.items(): for (task_name, metric), items in vals.items():
task = task_dict[task_name] task = task_dict[task_name]
results[task_name][metric] = task.aggregation()[metric](items) real_metric = metric # key when looking up the metric with task.aggregation
if metric.endswith(decontaminate_suffix):
real_metric = metric.replace(
decontaminate_suffix, ""
) # decontaminated still uses the same metric
results[task_name][metric] = task.aggregation()[real_metric](items)
# hotfix: bleu, chrf, ter seem to be really expensive to bootstrap # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
# so we run them less iterations. still looking for a cleaner way to do this # so we run them less iterations. still looking for a cleaner way to do this
stderr = lm_eval.metrics.stderr_for_metric(
    metric=task.aggregation()[real_metric],
    bootstrap_iters=min(bootstrap_iters, 1000)
    if metric in ["bleu", "chrf", "ter"]
    else bootstrap_iters,
)
if stderr is not None: if stderr is not None:
results[task_name][metric + "_stderr"] = stderr(items) results[task_name][metric + "_stderr"] = stderr(items)
return {"results": dict(results), "versions": dict(versions)}
def make_table(result_dict): def make_table(result_dict):
...@@ -247,9 +316,9 @@ def make_table(result_dict): ...@@ -247,9 +316,9 @@ def make_table(result_dict):
if m + "_stderr" in dic: if m + "_stderr" in dic:
se = dic[m + "_stderr"] se = dic[m + "_stderr"]
values.append([k, version, m, '%.4f' % v, '±', '%.4f' % se]) values.append([k, version, m, "%.4f" % v, "±", "%.4f" % se])
else: else:
values.append([k, version, m, '%.4f' % v, '', '']) values.append([k, version, m, "%.4f" % v, "", ""])
k = "" k = ""
version = "" version = ""
md_writer.value_matrix = values md_writer.value_matrix = values
......
...@@ -103,6 +103,7 @@ def weighted_mean(items): ...@@ -103,6 +103,7 @@ def weighted_mean(items):
def weighted_perplexity(items): def weighted_perplexity(items):
return math.exp(-weighted_mean(items)) return math.exp(-weighted_mean(items))
def bits_per_byte(items): def bits_per_byte(items):
return -weighted_mean(items) / math.log(2) return -weighted_mean(items) / math.log(2)
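As a reading aid (assuming, as elsewhere in the harness, that each item pairs a summed log-likelihood with a weight such as a token or byte count, so the weighted mean is the ratio of the two sums), the two aggregations reduce to:
import math

# Illustrative numbers only: (summed log-likelihood, number of tokens/bytes) per document.
items = [(-2.0, 4), (-1.0, 2)]
weighted = sum(ll for ll, _ in items) / sum(n for _, n in items)  # -0.5
print(math.exp(-weighted))      # weighted perplexity, about 1.6487
print(-weighted / math.log(2))  # bits per byte, about 0.7213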
...@@ -184,8 +185,10 @@ def _sacreformat(refs, preds): ...@@ -184,8 +185,10 @@ def _sacreformat(refs, preds):
return refs, preds return refs, preds
# stderr stuff # stderr stuff
class _bootstrap_internal: class _bootstrap_internal:
def __init__(self, f, n): def __init__(self, f, n):
self.f = f self.f = f
...@@ -203,9 +206,10 @@ class _bootstrap_internal: ...@@ -203,9 +206,10 @@ class _bootstrap_internal:
def bootstrap_stderr(f, xs, iters): def bootstrap_stderr(f, xs, iters):
import multiprocessing as mp import multiprocessing as mp
pool = mp.Pool(mp.cpu_count()) pool = mp.Pool(mp.cpu_count())
# This gives a biased estimate of the stderr (i.e. with the mean, it gives something
# equivalent to the stderr calculated without Bessel's correction in the stddev).
# Unfortunately, I haven't been able to figure out what the right correction is
# to make the bootstrap unbiased - I considered multiplying by sqrt(n/(n-1)), but
# that would be ad-hoc and I can't prove that it would actually be an unbiased estimator.
...@@ -213,10 +217,15 @@ def bootstrap_stderr(f, xs, iters): ...@@ -213,10 +217,15 @@ def bootstrap_stderr(f, xs, iters):
res = [] res = []
chunk_size = min(1000, iters) chunk_size = min(1000, iters)
from tqdm import tqdm from tqdm import tqdm
print("bootstrapping for stddev:", f.__name__) print("bootstrapping for stddev:", f.__name__)
for bootstrap in tqdm(
    pool.imap(
        _bootstrap_internal(f, chunk_size),
        [(i, xs) for i in range(iters // chunk_size)],
    ),
    total=iters // chunk_size,
):
# sample w replacement # sample w replacement
res.extend(bootstrap) res.extend(bootstrap)
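A single-process sketch of the bootstrap being parallelised here — resample with replacement, recompute the metric, and take the spread of the resampled statistics (pstdev, i.e. no Bessel correction, matching the caveat in the comment above); the harness's multiprocessing and chunking are omitted:
import random
import statistics

def bootstrap_stderr_sketch(f, xs, iters=1000, seed=0):
    rnd = random.Random(seed)
    # Each bootstrap sample draws len(xs) items from xs with replacement.
    stats = [f([rnd.choice(xs) for _ in xs]) for _ in range(iters)]
    # The stddev of the bootstrap distribution approximates the stderr of f(xs).
    return statistics.pstdev(stats)

print(bootstrap_stderr_sketch(lambda v: sum(v) / len(v), [0, 1, 1, 0, 1]))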
...@@ -238,17 +247,13 @@ def stderr_for_metric(metric, bootstrap_iters): ...@@ -238,17 +247,13 @@ def stderr_for_metric(metric, bootstrap_iters):
if metric in bootstrappable: if metric in bootstrappable:
return lambda x: bootstrap_stderr(metric, x, iters=bootstrap_iters) return lambda x: bootstrap_stderr(metric, x, iters=bootstrap_iters)
stderr = {mean: mean_stderr, acc_all: acc_all_stderr}
return stderr.get(metric, None) return stderr.get(metric, None)
def yesno(x): def yesno(x):
if x: if x:
return 'yes' return "yes"
else: else:
return 'no' return "no"
...@@ -4,8 +4,15 @@ from lm_eval.base import BaseLM ...@@ -4,8 +4,15 @@ from lm_eval.base import BaseLM
class HFLM(BaseLM): class HFLM(BaseLM):
def __init__(
    self,
    device="cuda",
    pretrained="gpt2",
    revision="main",
    subfolder=None,
    tokenizer=None,
    batch_size=1,
):
super().__init__() super().__init__()
assert isinstance(device, str) assert isinstance(device, str)
...@@ -13,30 +20,54 @@ class HFLM(BaseLM): ...@@ -13,30 +20,54 @@ class HFLM(BaseLM):
assert isinstance(batch_size, int) assert isinstance(batch_size, int)
if device:
    if device not in ["cuda", "cpu"]:
        device = int(device)
    self._device = torch.device(device)
    print(f"Using device '{device}'")
else:
    print("Device not specified")
    print(f"Cuda Available? {torch.cuda.is_available()}")
    self._device = (
        torch.device("cuda")
        if torch.cuda.is_available()
        else torch.device("cpu")
    )
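With this change the device argument also accepts a bare GPU ordinal, which is coerced with int(...); a small standalone sketch of the same branch (not the class itself, and torch must be installed):
import torch

def resolve_device(device):
    # Mirrors the branch above: "cuda"/"cpu" pass through, anything else is an index.
    if device:
        if device not in ["cuda", "cpu"]:
            device = int(device)
        return torch.device(device)
    return torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

print(resolve_device("cpu"))  # cpu
print(resolve_device("0"))    # cuda:0 (device index 0)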
# TODO: update this to be less of a hack once subfolder is fixed in HF # TODO: update this to be less of a hack once subfolder is fixed in HF
self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained(
    pretrained,
    revision=revision + ("/" + subfolder if subfolder is not None else ""),
).to(self.device)
self.gpt2.eval() self.gpt2.eval()
# pretrained tokenizer for neo is broken for now so just hard-coding this to gpt2 # pretrained tokenizer for neo is broken for now so just hard-coding this to gpt2
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
    pretrained if tokenizer is None else tokenizer,
    revision=revision,
    subfolder=subfolder,
)
assert isinstance(
    self.tokenizer,
    (
        transformers.GPT2Tokenizer,
        transformers.GPT2TokenizerFast,
        transformers.T5Tokenizer,
        transformers.T5TokenizerFast,
    ),
), "this tokenizer has not been checked for compatibility yet!"
self.vocab_size = self.tokenizer.vocab_size self.vocab_size = self.tokenizer.vocab_size
if isinstance(
    self.tokenizer, (transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast)
):
    assert self.tokenizer.encode("hello\n\nhello") == [
        31373,
        198,
        198,
        31373,
    ], self.tokenizer.encode("hello\n\nhello")
# multithreading and batching # multithreading and batching
self.batch_size_per_gpu = batch_size # todo: adaptive batch size self.batch_size_per_gpu = batch_size # todo: adaptive batch size
...@@ -75,7 +106,7 @@ class HFLM(BaseLM): ...@@ -75,7 +106,7 @@ class HFLM(BaseLM):
def tok_encode(self, string: str): def tok_encode(self, string: str):
return self.tokenizer.encode(string, add_special_tokens=False) return self.tokenizer.encode(string, add_special_tokens=False)
def tok_decode(self, tokens): def tok_decode(self, tokens):
return self.tokenizer.decode(tokens) return self.tokenizer.decode(tokens)
...@@ -89,13 +120,10 @@ class HFLM(BaseLM): ...@@ -89,13 +120,10 @@ class HFLM(BaseLM):
""" """
with torch.no_grad(): with torch.no_grad():
return self.gpt2(inps)[0][:, :, :50257] return self.gpt2(inps)[0][:, :, :50257]
def _model_generate(self, context, max_length, eos_token_id):
    return self.gpt2.generate(
        context, max_length=max_length, eos_token_id=eos_token_id, do_sample=False
    )
......
...@@ -31,22 +31,24 @@ def get_result(response, ctxlen): ...@@ -31,22 +31,24 @@ def get_result(response, ctxlen):
if top_token != token: if top_token != token:
is_greedy = False is_greedy = False
break break
return continuation_logprobs, is_greedy return continuation_logprobs, is_greedy
def oa_completion(**kwargs): def oa_completion(**kwargs):
""" Query OpenAI API for completion. """Query OpenAI API for completion.
Retry with back-off until they respond Retry with back-off until they respond
""" """
import openai import openai
backoff_time = 3 backoff_time = 3
while True: while True:
try: try:
return openai.Completion.create(**kwargs) return openai.Completion.create(**kwargs)
except openai.error.OpenAIError: except openai.error.OpenAIError:
import traceback import traceback
traceback.print_exc() traceback.print_exc()
time.sleep(backoff_time) time.sleep(backoff_time)
backoff_time *= 1.5 backoff_time *= 1.5
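The retry policy above is a plain exponential back-off on API errors; a generic, library-free sketch of the same pattern (the 3-second start and 1.5 multiplier mirror the values in the diff, while the max_tries cap is an addition for the example — the harness retries indefinitely):
import time

def retry_with_backoff(fn, is_retryable, start=3.0, factor=1.5, max_tries=8):
    backoff = start
    for attempt in range(max_tries):
        try:
            return fn()
        except Exception as exc:
            if not is_retryable(exc) or attempt == max_tries - 1:
                raise
            time.sleep(backoff)
            backoff *= factor

# Example (do_request is hypothetical): retry only on ConnectionError.
# retry_with_backoff(lambda: do_request(), lambda e: isinstance(e, ConnectionError))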
...@@ -66,16 +68,19 @@ class GPT3LM(BaseLM): ...@@ -66,16 +68,19 @@ class GPT3LM(BaseLM):
super().__init__() super().__init__()
import openai import openai
self.engine = engine self.engine = engine
self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2') self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained("gpt2")
self.vocab_size = self.tokenizer.vocab_size self.vocab_size = self.tokenizer.vocab_size
# to make the annoying "Using pad_token, but it is not set yet." error go away # to make the annoying "Using pad_token, but it is not set yet." error go away
self.tokenizer.pad_token = "<|endoftext|>" self.tokenizer.pad_token = "<|endoftext|>"
assert self.tokenizer.encode('hello\n\nhello') == [31373, 198, 198, 31373] assert self.tokenizer.encode("hello\n\nhello") == [31373, 198, 198, 31373]
self.truncate = truncate self.truncate = truncate
self.end_of_text_token_id = self.tokenizer.convert_tokens_to_ids(
    ["<|endoftext|>"]
)[0]
# Read from environment variable OPENAI_API_SECRET_KEY # Read from environment variable OPENAI_API_SECRET_KEY
openai.api_key = os.environ["OPENAI_API_SECRET_KEY"] openai.api_key = os.environ["OPENAI_API_SECRET_KEY"]
...@@ -105,7 +110,7 @@ class GPT3LM(BaseLM): ...@@ -105,7 +110,7 @@ class GPT3LM(BaseLM):
def tok_encode(self, string: str): def tok_encode(self, string: str):
return self.tokenizer.encode(string, add_special_tokens=False) return self.tokenizer.encode(string, add_special_tokens=False)
def tok_decode(self, tokens): def tok_decode(self, tokens):
return self.tokenizer.decode(tokens) return self.tokenizer.decode(tokens)
...@@ -118,17 +123,22 @@ class GPT3LM(BaseLM): ...@@ -118,17 +123,22 @@ class GPT3LM(BaseLM):
# we care about and so we need some kind of backup for when it isn't # we care about and so we need some kind of backup for when it isn't
toks = x[1] + x[2] toks = x[1] + x[2]
return -len(toks), tuple(toks) return -len(toks), tuple(toks)
re_ord = utils.Reorderer(requests, _collate)

for chunk in tqdm(
    list(utils.chunks(re_ord.get_reordered(), self.REQ_CHUNK_SIZE)),
    disable=disable_tqdm,
):
inps = [] inps = []
ctxlens = [] ctxlens = []
for cache_key, context_enc, continuation_enc in chunk: for cache_key, context_enc, continuation_enc in chunk:
# max_length+1 because the API takes up to 2049 tokens, including the first context token # max_length+1 because the API takes up to 2049 tokens, including the first context token
inp = (context_enc + continuation_enc)[-(self.max_length + 1) :]
# TODO: the logic is much simpler if we just look at the length of continuation tokens
ctxlen = len(context_enc) - max(
    0, len(context_enc) + len(continuation_enc) - (self.max_length + 1)
)
inps.append(inp) inps.append(inp)
ctxlens.append(ctxlen) ctxlens.append(ctxlen)
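A concrete instance of the truncation arithmetic above (numbers are made up; the API window in the diff is max_length + 1 tokens):
# Suppose max_length + 1 == 2049, the context has 2000 tokens and the continuation 100.
max_window = 2049
context_len, continuation_len = 2000, 100

# inp keeps only the last `max_window` tokens of context + continuation,
# so 51 context tokens fall off the front.
kept = min(max_window, context_len + continuation_len)   # 2049
dropped = context_len + continuation_len - kept          # 51

# ctxlen is how many of the kept tokens are still context:
ctxlen = context_len - max(0, context_len + continuation_len - max_window)
print(kept, dropped, ctxlen)  # 2049 51 1949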
...@@ -137,11 +147,14 @@ class GPT3LM(BaseLM): ...@@ -137,11 +147,14 @@ class GPT3LM(BaseLM):
engine=self.engine, engine=self.engine,
prompt=inps, prompt=inps,
echo=True, echo=True,
    max_tokens=0,
    temperature=0.0,
    logprobs=10,
)

for resp, ctxlen, (cache_key, context_enc, continuation_enc) in zip(
    response.choices, ctxlens, chunk
):
answer = get_result(resp, ctxlen) answer = get_result(resp, ctxlen)
res.append(answer) res.append(answer)
...@@ -150,7 +163,7 @@ class GPT3LM(BaseLM): ...@@ -150,7 +163,7 @@ class GPT3LM(BaseLM):
if cache_key is not None: if cache_key is not None:
self.cache_hook.add_partial("loglikelihood", cache_key, answer) self.cache_hook.add_partial("loglikelihood", cache_key, answer)
return reord.get_original(res) return re_ord.get_original(res)
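The re_ord object follows a common pattern: sort requests by a collation key so similar-length prompts are batched together, then scatter the results back into the caller's original order. A standalone sketch of the idea, not the harness's Reorderer class:
def run_in_sorted_order(items, key, process_batch):
    # Remember each item's original position, sort by the collation key,
    # process the sorted batch, then restore the original ordering.
    indexed = sorted(enumerate(items), key=lambda pair: key(pair[1]))
    results = process_batch([item for _, item in indexed])
    out = [None] * len(items)
    for (original_index, _), result in zip(indexed, results):
        out[original_index] = result
    return out

print(run_in_sorted_order(["aaa", "a", "aa"], len, lambda xs: [x.upper() for x in xs]))
# ['AAA', 'A', 'AA']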
def greedy_until(self, requests): def greedy_until(self, requests):
if not requests: if not requests:
...@@ -160,8 +173,8 @@ class GPT3LM(BaseLM): ...@@ -160,8 +173,8 @@ class GPT3LM(BaseLM):
def _collate(x): def _collate(x):
toks = self.tok_encode(x[0]) toks = self.tok_encode(x[0])
return len(toks), x[0] return len(toks), x[0]
reord = utils.Reorderer(requests, _collate) re_ord = utils.Reorderer(requests, _collate)
def sameuntil_chunks(xs, size): def sameuntil_chunks(xs, size):
ret = [] ret = []
...@@ -172,39 +185,41 @@ class GPT3LM(BaseLM): ...@@ -172,39 +185,41 @@ class GPT3LM(BaseLM):
ret = [] ret = []
lastuntil = x[1] lastuntil = x[1]
ret.append(x) ret.append(x)
if ret: if ret:
yield ret, lastuntil yield ret, lastuntil
# todo: more intelligent batching for heterogeneous `until` # todo: more intelligent batching for heterogeneous `until`
for chunk, until in tqdm(
    list(sameuntil_chunks(re_ord.get_reordered(), self.REQ_CHUNK_SIZE))
):
inps = [] inps = []
for context, _ in chunk: for context, _ in chunk:
context_enc = self.tok_encode(context) context_enc = self.tok_encode(context)
inp = context_enc[-(self.max_length - self.max_gen_toks):] inp = context_enc[-(self.max_length - self.max_gen_toks) :]
inps.append(inp) inps.append(inp)
response = oa_completion( response = oa_completion(
engine=self.engine, engine=self.engine,
prompt=inps, prompt=inps,
max_tokens=self.max_gen_toks, max_tokens=self.max_gen_toks,
temperature=0., temperature=0.0,
logprobs=10, logprobs=10,
stop=until, stop=until,
) )
for resp, (context, until_) in zip(response.choices, chunk): for resp, (context, until_) in zip(response.choices, chunk):
s = resp['text'] s = resp["text"]
for term in until_: for term in until_:
s = s.split(term)[0] s = s.split(term)[0]
# partial caching # partial caching
self.cache_hook.add_partial("greedy_until", (context, until_), s) self.cache_hook.add_partial("greedy_until", (context, until_), s)
res.append(s) res.append(s)
return reord.get_original(res) return re_ord.get_original(res)
def _model_call(self, inps): def _model_call(self, inps):
# Isn't used because we override _loglikelihood_tokens # Isn't used because we override _loglikelihood_tokens
......
...@@ -22,14 +22,12 @@ from . import naturalqs ...@@ -22,14 +22,12 @@ from . import naturalqs
from . import sat from . import sat
from . import arithmetic from . import arithmetic
from . import lambada from . import lambada
from . import race
from . import piqa from . import piqa
from . import prost from . import prost
from . import mc_taco from . import mc_taco
from . import triviaqa from . import triviaqa
from . import pubmedqa from . import pubmedqa
from . import sciq from . import sciq
from . import webqs
from . import qasper from . import qasper
from . import qa4mre from . import qa4mre
from . import translation from . import translation
...@@ -59,8 +57,8 @@ from . import storycloze ...@@ -59,8 +57,8 @@ from . import storycloze
# 6 total # 6 total
gpt3_translation_benchmarks = { gpt3_translation_benchmarks = {
"wmt14": ['en-fr', 'fr-en'], # French "wmt14": ["en-fr", "fr-en"], # French
"wmt16": ['en-ro', 'ro-en', 'de-en', 'en-de'], # German, Romanian "wmt16": ["en-ro", "ro-en", "de-en", "en-de"], # German, Romanian
} }
...@@ -68,7 +66,7 @@ gpt3_translation_benchmarks = { ...@@ -68,7 +66,7 @@ gpt3_translation_benchmarks = {
selected_translation_benchmarks = { selected_translation_benchmarks = {
**gpt3_translation_benchmarks, **gpt3_translation_benchmarks,
"wmt20": sacrebleu.get_langpairs_for_testset("wmt20"), "wmt20": sacrebleu.get_langpairs_for_testset("wmt20"),
"iwslt17": ['en-ar', 'ar-en'] # Arabic "iwslt17": ["en-ar", "ar-en"], # Arabic
} }
# 319 total # 319 total
...@@ -92,7 +90,7 @@ TASK_REGISTRY = { ...@@ -92,7 +90,7 @@ TASK_REGISTRY = {
"rte": glue.RTE, "rte": glue.RTE,
"qnli": glue.QNLI, "qnli": glue.QNLI,
"qqp": glue.QQP, "qqp": glue.QQP,
#"stsb": glue.STSB, # not implemented yet # "stsb": glue.STSB, # not implemented yet
"sst": glue.SST, "sst": glue.SST,
"wnli": glue.WNLI, "wnli": glue.WNLI,
# SuperGLUE # SuperGLUE
...@@ -103,34 +101,26 @@ TASK_REGISTRY = { ...@@ -103,34 +101,26 @@ TASK_REGISTRY = {
"record": superglue.ReCoRD, "record": superglue.ReCoRD,
"wic": superglue.WordsInContext, "wic": superglue.WordsInContext,
"wsc": superglue.SGWinogradSchemaChallenge, "wsc": superglue.SGWinogradSchemaChallenge,
# Order by benchmark/genre? # Order by benchmark/genre?
"coqa": coqa.CoQA, "coqa": coqa.CoQA,
"drop": drop.DROP, "drop": drop.DROP,
"lambada": lambada.LAMBADA, "lambada": lambada.LAMBADA,
"lambada_cloze": lambada_cloze.LAMBADA_cloze, "lambada_cloze": lambada_cloze.LAMBADA_cloze,
# multilingual lambada # multilingual lambada
**lambada_multilingual.construct_tasks(), **lambada_multilingual.construct_tasks(),
"wikitext": wikitext.WikiText, "wikitext": wikitext.WikiText,
# "cbt-cn": cbt.CBTCN, # disabled pending context length fix # "cbt-cn": cbt.CBTCN, # disabled pending context length fix
# "cbt-ne": cbt.CBTNE, # disabled pending context length fix # "cbt-ne": cbt.CBTNE, # disabled pending context length fix
"piqa": piqa.PiQA, "piqa": piqa.PiQA,
"prost": prost.PROST, "prost": prost.PROST,
"mc_taco": mc_taco.MCTACO, "mc_taco": mc_taco.MCTACO,
# Science related # Science related
"pubmedqa" : pubmedqa.Pubmed_QA, "pubmedqa": pubmedqa.Pubmed_QA,
"sciq" : sciq.SciQ, "sciq": sciq.SciQ,
"qasper": qasper.QASPER, "qasper": qasper.QASPER,
"qa4mre_2011": qa4mre.QA4MRE_2011,
"qa4mre_2011" : qa4mre.QA4MRE_2011, "qa4mre_2012": qa4mre.QA4MRE_2012,
"qa4mre_2012" : qa4mre.QA4MRE_2012, "qa4mre_2013": qa4mre.QA4MRE_2013,
"qa4mre_2013" : qa4mre.QA4MRE_2013,
"triviaqa": triviaqa.TriviaQA, "triviaqa": triviaqa.TriviaQA,
"arc_easy": arc.ARCEasy, "arc_easy": arc.ARCEasy,
"arc_challenge": arc.ARCChallenge, "arc_challenge": arc.ARCChallenge,
...@@ -142,7 +132,7 @@ TASK_REGISTRY = { ...@@ -142,7 +132,7 @@ TASK_REGISTRY = {
"squad2": squad.SQuAD2, "squad2": squad.SQuAD2,
"race": race.RACE, "race": race.RACE,
# "naturalqs": naturalqs.NaturalQs, # not implemented yet # "naturalqs": naturalqs.NaturalQs, # not implemented yet
"headqa": headqa.HeadQAEsDeprecated, # for backwards compat - headqa used to default to es "headqa": headqa.HeadQAEsDeprecated, # for backwards compat - headqa used to default to es
"headqa_es": headqa.HeadQAEs, "headqa_es": headqa.HeadQAEs,
"headqa_en": headqa.HeadQAEn, "headqa_en": headqa.HeadQAEn,
"mathqa": mathqa.MathQA, "mathqa": mathqa.MathQA,
...@@ -152,21 +142,17 @@ TASK_REGISTRY = { ...@@ -152,21 +142,17 @@ TASK_REGISTRY = {
"anli_r1": anli.ANLIRound1, "anli_r1": anli.ANLIRound1,
"anli_r2": anli.ANLIRound2, "anli_r2": anli.ANLIRound2,
"anli_r3": anli.ANLIRound3, "anli_r3": anli.ANLIRound3,
"ethics_cm": hendrycks_ethics.EthicsCM, "ethics_cm": hendrycks_ethics.EthicsCM,
"ethics_deontology": hendrycks_ethics.EthicsDeontology, "ethics_deontology": hendrycks_ethics.EthicsDeontology,
"ethics_justice": hendrycks_ethics.EthicsJustice, "ethics_justice": hendrycks_ethics.EthicsJustice,
"ethics_utilitarianism_original": hendrycks_ethics.EthicsUtilitarianismOriginal, "ethics_utilitarianism_original": hendrycks_ethics.EthicsUtilitarianismOriginal,
"ethics_utilitarianism": hendrycks_ethics.EthicsUtilitarianism, "ethics_utilitarianism": hendrycks_ethics.EthicsUtilitarianism,
"ethics_virtue": hendrycks_ethics.EthicsVirtue, "ethics_virtue": hendrycks_ethics.EthicsVirtue,
"truthfulqa_mc": truthfulqa.TruthfulQAMultipleChoice,
"truthfulqa_mc": truthfulqa.TruthfulQAMultipleChoice, "truthfulqa_gen": truthfulqa.TruthfulQAGeneration,
"truthfulqa_gen": truthfulqa.TruthfulQAGeneration,
# dialogue # dialogue
"mutual": mutual.MuTual, "mutual": mutual.MuTual,
"mutual_plus": mutual.MuTualPlus, "mutual_plus": mutual.MuTualPlus,
# math # math
"math_algebra": hendrycks_math.MathAlgebra, "math_algebra": hendrycks_math.MathAlgebra,
"math_counting_and_prob": hendrycks_math.MathCountingAndProbability, "math_counting_and_prob": hendrycks_math.MathCountingAndProbability,
...@@ -177,7 +163,6 @@ TASK_REGISTRY = { ...@@ -177,7 +163,6 @@ TASK_REGISTRY = {
"math_precalc": hendrycks_math.MathPrecalculus, "math_precalc": hendrycks_math.MathPrecalculus,
"math_asdiv": asdiv.Asdiv, "math_asdiv": asdiv.Asdiv,
"gsm8k": gsm8k.GradeSchoolMath8K, "gsm8k": gsm8k.GradeSchoolMath8K,
# arithmetic # arithmetic
"arithmetic_2da": arithmetic.Arithmetic2DPlus, "arithmetic_2da": arithmetic.Arithmetic2DPlus,
"arithmetic_2ds": arithmetic.Arithmetic2DMinus, "arithmetic_2ds": arithmetic.Arithmetic2DMinus,
...@@ -191,22 +176,18 @@ TASK_REGISTRY = { ...@@ -191,22 +176,18 @@ TASK_REGISTRY = {
"arithmetic_1dc": arithmetic.Arithmetic1DComposite, "arithmetic_1dc": arithmetic.Arithmetic1DComposite,
# TODO Perhaps make these groups of tasks # TODO Perhaps make these groups of tasks
# e.g. anli, arithmetic, openai_translations, harness_translations # e.g. anli, arithmetic, openai_translations, harness_translations
# hendrycksTest (57 tasks) # hendrycksTest (57 tasks)
**hendrycks_test.create_all_tasks(), **hendrycks_test.create_all_tasks(),
# e.g. wmt14-fr-en # e.g. wmt14-fr-en
**translation.create_tasks_from_benchmarks(gpt3_translation_benchmarks), **translation.create_tasks_from_benchmarks(gpt3_translation_benchmarks),
# chef's selection, mostly wmt20 # chef's selection, mostly wmt20
**translation.create_tasks_from_benchmarks(selected_translation_benchmarks), **translation.create_tasks_from_benchmarks(selected_translation_benchmarks),
# Word Scrambling and Manipulation Tasks # Word Scrambling and Manipulation Tasks
"anagrams1": unscramble.Anagrams1, "anagrams1": unscramble.Anagrams1,
"anagrams2": unscramble.Anagrams2, "anagrams2": unscramble.Anagrams2,
"cycle_letters": unscramble.CycleLetters, "cycle_letters": unscramble.CycleLetters,
"random_insertion": unscramble.RandomInsertion, "random_insertion": unscramble.RandomInsertion,
"reversed_words": unscramble.ReversedWords, "reversed_words": unscramble.ReversedWords,
# Pile # Pile
"pile_arxiv": pile.PileArxiv, "pile_arxiv": pile.PileArxiv,
"pile_books3": pile.PileBooks3, "pile_books3": pile.PileBooks3,
...@@ -230,7 +211,6 @@ TASK_REGISTRY = { ...@@ -230,7 +211,6 @@ TASK_REGISTRY = {
"pile_ubuntu-irc": pile.PileUbuntuIrc, "pile_ubuntu-irc": pile.PileUbuntuIrc,
"pile_wikipedia": pile.PileWikipedia, "pile_wikipedia": pile.PileWikipedia,
"pile_youtubesubtitles": pile.PileYoutubeSubtitles, "pile_youtubesubtitles": pile.PileYoutubeSubtitles,
# BLiMP # BLiMP
"blimp_adjunct_island": blimp.BlimpAdjunctIsland, "blimp_adjunct_island": blimp.BlimpAdjunctIsland,
"blimp_anaphor_gender_agreement": blimp.BlimpAnaphorGenderAgreement, "blimp_anaphor_gender_agreement": blimp.BlimpAnaphorGenderAgreement,
...@@ -299,7 +279,6 @@ TASK_REGISTRY = { ...@@ -299,7 +279,6 @@ TASK_REGISTRY = {
"blimp_wh_vs_that_no_gap_long_distance": blimp.BlimpWhVsThatNoGapLongDistance, "blimp_wh_vs_that_no_gap_long_distance": blimp.BlimpWhVsThatNoGapLongDistance,
"blimp_wh_vs_that_with_gap": blimp.BlimpWhVsThatWithGap, "blimp_wh_vs_that_with_gap": blimp.BlimpWhVsThatWithGap,
"blimp_wh_vs_that_with_gap_long_distance": blimp.BlimpWhVsThatWithGapLongDistance, "blimp_wh_vs_that_with_gap_long_distance": blimp.BlimpWhVsThatWithGapLongDistance,
# Requires manual download of data. # Requires manual download of data.
# "storycloze_2016": storycloze.StoryCloze2016, # "storycloze_2016": storycloze.StoryCloze2016,
# "storycloze_2018": storycloze.StoryCloze2018, # "storycloze_2018": storycloze.StoryCloze2018,
...@@ -313,7 +292,7 @@ ALL_TASKS = sorted(list(TASK_REGISTRY)) ...@@ -313,7 +292,7 @@ ALL_TASKS = sorted(list(TASK_REGISTRY))
def get_task(task_name): def get_task(task_name):
try: try:
return TASK_REGISTRY[task_name] return TASK_REGISTRY[task_name]
except KeyError as e: except KeyError:
print("Available tasks:") print("Available tasks:")
pprint(TASK_REGISTRY) pprint(TASK_REGISTRY)
raise KeyError(f"Missing task {task_name}") raise KeyError(f"Missing task {task_name}")
...@@ -323,19 +302,25 @@ def get_task_name_from_object(task_object): ...@@ -323,19 +302,25 @@ def get_task_name_from_object(task_object):
for name, class_ in TASK_REGISTRY.items(): for name, class_ in TASK_REGISTRY.items():
if class_ is task_object: if class_ is task_object:
return name return name
# this gives a mechanism for non-registered tasks to have a custom name anyways when reporting # this gives a mechanism for non-registered tasks to have a custom name anyways when reporting
return (
    task_object.EVAL_HARNESS_NAME
    if hasattr(task_object, "EVAL_HARNESS_NAME")
    else type(task_object).__name__
)
def get_task_dict(task_name_list: List[Union[str, lm_eval.base.Task]]): def get_task_dict(task_name_list: List[Union[str, lm_eval.base.Task]]):
task_name_dict = {
    task_name: get_task(task_name)()
    for task_name in task_name_list
    if isinstance(task_name, str)
}
task_name_from_object_dict = {
    get_task_name_from_object(task_object): task_object
    for task_object in task_name_list
    if not isinstance(task_object, str)
}
assert set(task_name_dict.keys()).isdisjoint(set(task_name_from_object_dict.keys())) assert set(task_name_dict.keys()).isdisjoint(set(task_name_from_object_dict.keys()))
return {**task_name_dict, **task_name_from_object_dict} return {**task_name_dict, **task_name_from_object_dict}
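Usage sketch for the mixed-input behaviour above; the task names are illustrative and MyCustomTask stands in for any already-constructed Task instance:
# import lm_eval.tasks as tasks
# task_dict = tasks.get_task_dict(["lambada", "arc_easy", MyCustomTask()])
# -> {"lambada": <LAMBADA>, "arc_easy": <ARCEasy>, "MyCustomTask": <MyCustomTask>}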
...@@ -61,36 +61,47 @@ class ANLIBase(Task): ...@@ -61,36 +61,47 @@ class ANLIBase(Task):
def doc_to_text(self, doc): def doc_to_text(self, doc):
# OA does this a bit weirdly: they prepend "anli 1: anli 1: " to the beginning # OA does this a bit weirdly: they prepend "anli 1: anli 1: " to the beginning
# of the prompt (yes, repeating it!). also, " True, False, or Neither?" is directly # of the prompt (yes, repeating it!). also, " True, False, or Neither?" is directly
# appended onto the question, with no "Answer:" or even a newline. Do we *really* # appended onto the question, with no "Answer:" or even a newline. Do we *really*
# want to do it exactly as OA did? # want to do it exactly as OA did?
return (
    doc["premise"]
    + "\nQuestion: "
    + doc["hypothesis"]
    + " True, False, or Neither?\nAnswer:"
)
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["premise"]
def doc_to_target(self, doc): def doc_to_target(self, doc):
# True = entailment # True = entailment
# False = contradiction # False = contradiction
# Neither = neutral # Neither = neutral
return " " + ["True", "Neither", "False"][doc['label']] return " " + ["True", "Neither", "False"][doc["label"]]
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of """Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM. Requests which will be sent to the LM.
:param doc: :param doc:
The document as returned from training_docs, validation_docs, or test_docs. The document as returned from training_docs, validation_docs, or test_docs.
:param ctx: str :param ctx: str
The context string, generated by fewshot_context. This includes the natural The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question language description, as well as the few shot examples, and the question
part of the document for `doc`. part of the document for `doc`.
""" """
ll_true, _ = rf.loglikelihood(ctx, " True") ll_true, _ = rf.loglikelihood(ctx, " True")
ll_neither, _ = rf.loglikelihood(ctx, " Neither") ll_neither, _ = rf.loglikelihood(ctx, " Neither")
ll_false, _ = rf.loglikelihood(ctx, " False") ll_false, _ = rf.loglikelihood(ctx, " False")
return ll_true, ll_neither, ll_false return ll_true, ll_neither, ll_false
def process_results(self, doc, results): def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a """Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of dict where keys are the names of submetrics and values are the values of
the metric for that one document the metric for that one document
:param doc: :param doc:
...@@ -100,29 +111,23 @@ class ANLIBase(Task): ...@@ -100,29 +111,23 @@ class ANLIBase(Task):
""" """
gold = doc["label"] gold = doc["label"]
pred = np.argmax(results) pred = np.argmax(results)
return {"acc": pred == gold}
def aggregation(self): def aggregation(self):
""" """
:returns: {str: [float] -> float} :returns: {str: [float] -> float}
A dictionary where keys are the names of submetrics and values are A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics functions that aggregate a list of metrics
""" """
return {"acc": mean}
def higher_is_better(self): def higher_is_better(self):
""" """
:returns: {str: bool} :returns: {str: bool}
A dictionary where keys are the names of submetrics and values are A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better whether a higher value of the submetric is better
""" """
return {"acc": True}
class ANLIRound1(ANLIBase): class ANLIRound1(ANLIBase):
......
...@@ -67,6 +67,12 @@ class ARCEasy(MultipleChoiceTask): ...@@ -67,6 +67,12 @@ class ARCEasy(MultipleChoiceTask):
def doc_to_text(self, doc): def doc_to_text(self, doc):
return doc["query"] return doc["query"]
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["query"]
class ARCChallenge(ARCEasy): class ARCChallenge(ARCEasy):
DATASET_PATH = "ai2_arc" DATASET_PATH = "ai2_arc"
......
...@@ -49,10 +49,16 @@ class Arithmetic(Task): ...@@ -49,10 +49,16 @@ class Arithmetic(Task):
def test_docs(self): def test_docs(self):
return NotImplemented return NotImplemented
def doc_to_text(self, doc): def doc_to_text(self, doc):
return doc["context"] return doc["context"]
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["context"]
def doc_to_target(self, doc): def doc_to_target(self, doc):
return doc["completion"] return doc["completion"]
...@@ -61,10 +67,8 @@ class Arithmetic(Task): ...@@ -61,10 +67,8 @@ class Arithmetic(Task):
return is_prediction return is_prediction
def process_results(self, doc, results): def process_results(self, doc, results):
(is_prediction,) = results
return {"acc": is_prediction}
def aggregation(self): def aggregation(self):
return { return {
...@@ -72,9 +76,7 @@ class Arithmetic(Task): ...@@ -72,9 +76,7 @@ class Arithmetic(Task):
} }
def higher_is_better(self): def higher_is_better(self):
return {"acc": True}
class Arithmetic2DPlus(Arithmetic): class Arithmetic2DPlus(Arithmetic):
......
...@@ -54,42 +54,41 @@ class Asdiv(Task): ...@@ -54,42 +54,41 @@ class Asdiv(Task):
def test_docs(self): def test_docs(self):
raise NotImplementedError("This dataset has no test docs") raise NotImplementedError("This dataset has no test docs")
def fewshot_context(
    self, doc, num_fewshot, provide_description=None, rnd=None, description=None
):
    assert num_fewshot == 0, "ASDiv is intended only for the zero-shot setting."
    return super().fewshot_context(
        doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
    )
def doc_to_text(self, doc): def doc_to_text(self, doc):
# TODO: add solution-type # TODO: add solution-type
return doc['body'] + '\n' + 'Question:' + doc['question'] + '\n' + 'Answer:' return doc["body"] + "\n" + "Question:" + doc["question"] + "\n" + "Answer:"
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["body"] + " " + doc["question"]
def doc_to_target(self, doc): def doc_to_target(self, doc):
# TODO: add formula # TODO: add formula
answer = doc['answer'].split(' (')[0] answer = doc["answer"].split(" (")[0]
return " " + answer return " " + answer
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
ll, is_greedy = rf.loglikelihood(ctx, self.doc_to_target(doc)) ll, is_greedy = rf.loglikelihood(ctx, self.doc_to_target(doc))
return ll, is_greedy return ll, is_greedy
def process_results(self, doc, results): def process_results(self, doc, results):
ll, is_greedy = results ll, is_greedy = results
return {"acc": int(is_greedy)}
def aggregation(self): def aggregation(self):
return {"acc": mean}
def higher_is_better(self): def higher_is_better(self):
return {"acc": True}
...@@ -28,7 +28,7 @@ _CITATION = """ ...@@ -28,7 +28,7 @@ _CITATION = """
eprint = {https://doi.org/10.1162/tacl_a_00321}, eprint = {https://doi.org/10.1162/tacl_a_00321},
abstract = { We introduce The Benchmark of Linguistic Minimal Pairs (BLiMP),1 a challenge set for evaluating the linguistic knowledge of language models (LMs) on major grammatical phenomena in English. BLiMP consists of 67 individual datasets, each containing 1,000 minimal pairs—that is, pairs of minimally different sentences that contrast in grammatical acceptability and isolate specific phenomenon in syntax, morphology, or semantics. We generate the data according to linguist-crafted grammar templates, and human aggregate agreement with the labels is 96.4\%. We evaluate n-gram, LSTM, and Transformer (GPT-2 and Transformer-XL) LMs by observing whether they assign a higher probability to the acceptable sentence in each minimal pair. We find that state-of-the-art models identify morphological contrasts related to agreement reliably, but they struggle with some subtle semantic and syntactic phenomena, such as negative polarity items and extraction islands. } abstract = { We introduce The Benchmark of Linguistic Minimal Pairs (BLiMP),1 a challenge set for evaluating the linguistic knowledge of language models (LMs) on major grammatical phenomena in English. BLiMP consists of 67 individual datasets, each containing 1,000 minimal pairs—that is, pairs of minimally different sentences that contrast in grammatical acceptability and isolate specific phenomenon in syntax, morphology, or semantics. We generate the data according to linguist-crafted grammar templates, and human aggregate agreement with the labels is 96.4\%. We evaluate n-gram, LSTM, and Transformer (GPT-2 and Transformer-XL) LMs by observing whether they assign a higher probability to the acceptable sentence in each minimal pair. We find that state-of-the-art models identify morphological contrasts related to agreement reliably, but they struggle with some subtle semantic and syntactic phenomena, such as negative polarity items and extraction islands. }
} }
""" """ # noqa: W605
class BlimpTask(Task): class BlimpTask(Task):
...@@ -37,7 +37,7 @@ class BlimpTask(Task): ...@@ -37,7 +37,7 @@ class BlimpTask(Task):
def has_training_docs(self): def has_training_docs(self):
return False return False
def has_validation_docs(self): def has_validation_docs(self):
return True return True
...@@ -50,9 +50,13 @@ class BlimpTask(Task): ...@@ -50,9 +50,13 @@ class BlimpTask(Task):
# trained on this data. # trained on this data.
return self.dataset["train"] return self.dataset["train"]
def fewshot_context(
    self, doc, num_fewshot, provide_description=None, rnd=None, description=None
):
    assert num_fewshot == 0
    assert (
        rnd is not None
    ), "A `random.Random` generator argument must be provided to `rnd`"
assert not provide_description, ( assert not provide_description, (
"The `provide_description` arg will be removed in future versions. To prepend " "The `provide_description` arg will be removed in future versions. To prepend "
"a custom description to the context, supply the corresponding string via the " "a custom description to the context, supply the corresponding string via the "
...@@ -60,7 +64,9 @@ class BlimpTask(Task): ...@@ -60,7 +64,9 @@ class BlimpTask(Task):
) )
if provide_description is not None: if provide_description is not None:
# nudge people to not specify it at all # nudge people to not specify it at all
print("WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict") print(
"WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
)
return "" return ""
...@@ -68,6 +74,12 @@ class BlimpTask(Task): ...@@ -68,6 +74,12 @@ class BlimpTask(Task):
# this method is invoked by tests only # this method is invoked by tests only
return "" return ""
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["sentence_good"] + " " + doc["sentence_bad"]
def doc_to_target(self, doc): def doc_to_target(self, doc):
# this method is invoked by tests only # this method is invoked by tests only
return "" return ""
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
The Children’s Book Test (CBT) from the paper: The Children’s Book Test (CBT) from the paper:
https://research.fb.com/wp-content/uploads/2016/11/the_goldilocks_principle_reading_children_s_books_with_explicit_memory_representations.pdf https://research.fb.com/wp-content/uploads/2016/11/the_goldilocks_principle_reading_children_s_books_with_explicit_memory_representations.pdf
The Children's Book Test (CBT) is a test of how well language models capture
meaning in children's books. Unlike standard language modelling benchmarks, meaning in children's books. Unlike standard language modelling benchmarks,
it distinguishes the task of predicting syntactic function words from that it distinguishes the task of predicting syntactic function words from that
of predicting lower-frequency words, which carry greater semantic content. of predicting lower-frequency words, which carry greater semantic content.
...@@ -19,7 +19,7 @@ from lm_eval.metrics import mean ...@@ -19,7 +19,7 @@ from lm_eval.metrics import mean
_CITATION = """ _CITATION = """
@misc{hill2016goldilocks, @misc{hill2016goldilocks,
title={The Goldilocks Principle: Reading Children's Books with Explicit Memory Representations}, title={The Goldilocks Principle: Reading Children's Books with Explicit Memory Representations},
author={Felix Hill and Antoine Bordes and Sumit Chopra and Jason Weston}, author={Felix Hill and Antoine Bordes and Sumit Chopra and Jason Weston},
year={2016}, year={2016},
eprint={1511.02301}, eprint={1511.02301},
...@@ -75,11 +75,20 @@ class CBTBase(Task): ...@@ -75,11 +75,20 @@ class CBTBase(Task):
text = "Passage: " + passage + "\nQuestion: " + doc["question"] text = "Passage: " + passage + "\nQuestion: " + doc["question"]
return self.detokenize(text) return self.detokenize(text)
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
passage = " ".join(doc["sentences"])
return passage
def doc_to_target(self, doc): def doc_to_target(self, doc):
return "" return ""
def fewshot_examples(self, k, rnd): def fewshot_examples(self, k, rnd):
assert (
    k == 0
), f"CBT is only implemented for the zero-shot setting. Given k={k}."
return super().fewshot_examples(k, rnd) return super().fewshot_examples(k, rnd)
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
...@@ -113,9 +122,7 @@ class CBTBase(Task): ...@@ -113,9 +122,7 @@ class CBTBase(Task):
""" """
gold = doc["options"].index(doc["answer"]) gold = doc["options"].index(doc["answer"])
pred = np.argmax(results) pred = np.argmax(results)
return {"acc": pred == gold}
def aggregation(self): def aggregation(self):
""" """
...@@ -123,9 +130,7 @@ class CBTBase(Task): ...@@ -123,9 +130,7 @@ class CBTBase(Task):
A dictionary where keys are the names of submetrics and values are A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics functions that aggregate a list of metrics
""" """
return {"acc": mean}
def higher_is_better(self): def higher_is_better(self):
""" """
...@@ -133,9 +138,7 @@ class CBTBase(Task): ...@@ -133,9 +138,7 @@ class CBTBase(Task):
A dictionary where keys are the names of submetrics and values are A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better whether a higher value of the submetric is better
""" """
return {"acc": True}
class CBTCN(CBTBase): class CBTCN(CBTBase):
......
...@@ -2,9 +2,9 @@ ...@@ -2,9 +2,9 @@
CoQA: A Conversational Question Answering Challenge CoQA: A Conversational Question Answering Challenge
https://arxiv.org/pdf/1808.07042.pdf https://arxiv.org/pdf/1808.07042.pdf
CoQA is a large-scale dataset for building Conversational Question Answering CoQA is a large-scale dataset for building Conversational Question Answering
systems. The goal of the CoQA challenge is to measure the ability of machines to systems. The goal of the CoQA challenge is to measure the ability of machines to
understand a text passage and answer a series of interconnected questions that understand a text passage and answer a series of interconnected questions that
appear in a conversation. appear in a conversation.
Homepage: https://stanfordnlp.github.io/coqa/ Homepage: https://stanfordnlp.github.io/coqa/
...@@ -52,43 +52,53 @@ class CoQA(Task): ...@@ -52,43 +52,53 @@ class CoQA(Task):
pass pass
def doc_to_text(self, doc): def doc_to_text(self, doc):
# Given a passage p, the conversation history {q1, a1, . . . qi−1, ai−1} # Given a passage p, the conversation history {q1, a1, . . . qi−1, ai−1}
# and a question qi, the task is to predict the answer ai # and a question qi, the task is to predict the answer ai
doc_text = doc["story"] + '\n\n' doc_text = doc["story"] + "\n\n"
for (q, a) in zip_longest(doc["questions"]["input_text"], doc["answers"]["input_text"][:-1]): # omit target answer ai for (q, a) in zip_longest(
doc["questions"]["input_text"], doc["answers"]["input_text"][:-1]
): # omit target answer ai
question = f"Q: {q}\n\n" question = f"Q: {q}\n\n"
answer = f"A: {a}\n\n" if a is not None else "A:" answer = f"A: {a}\n\n" if a is not None else "A:"
doc_text += question + answer doc_text += question + answer
return doc_text return doc_text
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["story"] + " " + "\n".join(doc["questions"]["input_text"])
@classmethod @classmethod
def get_answers(cls, doc, turn_id): def get_answers(cls, doc, turn_id):
# Returns unique answers and valid alternatives (Some questions in CoQA have multiple valid answers). # Returns unique answers and valid alternatives (Some questions in CoQA have multiple valid answers).
answers = [] answers = []
answer_forturn = doc["answers"]["input_text"][turn_id - 1] answer_forturn = doc["answers"]["input_text"][turn_id - 1]
answers.append(answer_forturn) answers.append(answer_forturn)
additional_answers = doc.get("additional_answers") additional_answers = doc.get("additional_answers")
if additional_answers: if additional_answers:
for key in additional_answers: for key in additional_answers:
additional_answer_for_turn = additional_answers[key]["input_text"][
    turn_id - 1
]
if additional_answer_for_turn.lower() not in map(str.lower, answers): if additional_answer_for_turn.lower() not in map(str.lower, answers):
answers.append(additional_answer_for_turn) answers.append(additional_answer_for_turn)
return answers return answers
@classmethod @classmethod
def get_answer_choice(self, raw_text): def get_answer_choice(self, raw_text):
# Function maps answers to CoQA answer categories # Function maps answers to CoQA answer categories
# ~ 1/5 of the CoQA answers are Yes/No # ~ 1/5 of the CoQA answers are Yes/No
# ~ 2/3 of the CoQA answers are span-based # ~ 2/3 of the CoQA answers are span-based
# (answers overlap with the passage ignoring punctuation and case mismatch) # (answers overlap with the passage ignoring punctuation and case mismatch)
if raw_text == "unknown": if raw_text == "unknown":
return '0' return "0"
if squad_metrics.normalize_answer(raw_text) == "yes": if squad_metrics.normalize_answer(raw_text) == "yes":
return '1' return "1"
if squad_metrics.normalize_answer(raw_text) == "no": if squad_metrics.normalize_answer(raw_text) == "no":
return '2' return "2"
return '3' # Not a yes/no question return "3" # Not a yes/no question
@staticmethod @staticmethod
def compute_scores(gold_list, pred): def compute_scores(gold_list, pred):
...@@ -98,40 +108,45 @@ class CoQA(Task): ...@@ -98,40 +108,45 @@ class CoQA(Task):
em_sum = 0.0 em_sum = 0.0
if len(gold_list) > 1: if len(gold_list) > 1:
for i in range(len(gold_list)): for i in range(len(gold_list)):
gold_answers = gold_list[0:i] + gold_list[i + 1:] gold_answers = gold_list[0:i] + gold_list[i + 1 :]
# predictions compared against (n) golds and take maximum # predictions compared against (n) golds and take maximum
em_sum += max(
    squad_metrics.compute_exact(a, pred) for a in gold_answers
)
f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_answers) f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_answers)
else: else:
em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_list) em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_list)
f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_list) f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_list)
return {
    "em": em_sum / max(1, len(gold_list)),
    "f1": f1_sum / max(1, len(gold_list)),
}
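A worked instance of the scoring above when several gold answers exist: each gold is held out in turn and the prediction is compared against the remaining golds, which keeps a model from being penalised for matching only one annotator. Exact match only, with SQuAD-style normalisation omitted for brevity:
def leave_one_out_em(gold_list, pred):
    if len(gold_list) <= 1:
        return float(pred == gold_list[0]) if gold_list else 0.0
    em_sum = 0.0
    for i in range(len(gold_list)):
        others = gold_list[:i] + gold_list[i + 1:]
        em_sum += max(float(pred == g) for g in others)
    return em_sum / len(gold_list)

print(leave_one_out_em(["blue", "light blue"], "blue"))  # 0.5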
def doc_to_target(self, doc, turnid=None): def doc_to_target(self, doc, turnid=None):
# Default to prediction of last turn. # Default to prediction of last turn.
if turnid is None: if turnid is None:
turnid = len(doc["questions"]["input_text"]) turnid = len(doc["questions"]["input_text"])
raw_text = doc['answers']["input_text"][turnid - 1] raw_text = doc["answers"]["input_text"][turnid - 1]
return " " + raw_text return " " + raw_text
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of """Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM. Requests which will be sent to the LM.
:param doc: :param doc:
The document as returned from training_docs, validation_docs, or test_docs. The document as returned from training_docs, validation_docs, or test_docs.
:param ctx: str :param ctx: str
The context string, generated by fewshot_context. This includes the natural The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question language description, as well as the few shot examples, and the question
part of the document for `doc`. part of the document for `doc`.
""" """
cont_request = rf.greedy_until(ctx, ['\nQ:']) cont_request = rf.greedy_until(ctx, ["\nQ:"])
return cont_request return cont_request
def process_results(self, doc, results): def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a """Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of dict where keys are the names of submetrics and values are the values of
the metric for that one document the metric for that one document
:param doc: :param doc:
...@@ -141,13 +156,13 @@ class CoQA(Task): ...@@ -141,13 +156,13 @@ class CoQA(Task):
""" """
turn_id = len(doc["questions"]["input_text"]) turn_id = len(doc["questions"]["input_text"])
gold_list = self.get_answers(doc, turn_id) gold_list = self.get_answers(doc, turn_id)
pred = results[0].strip().split('\n')[0] pred = results[0].strip().split("\n")[0]
scores = self.compute_scores(gold_list, pred) scores = self.compute_scores(gold_list, pred)
return { return {
"f1": scores['f1'], "f1": scores["f1"],
"em": scores['em'], "em": scores["em"],
} }
def higher_is_better(self): def higher_is_better(self):
......
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
DROP: A Reading Comprehension Benchmark Requiring Discrete Reasoning Over Paragraphs DROP: A Reading Comprehension Benchmark Requiring Discrete Reasoning Over Paragraphs
https://aclanthology.org/attachments/N19-1246.Supplementary.pdf https://aclanthology.org/attachments/N19-1246.Supplementary.pdf
DROP is a QA dataset which tests comprehensive understanding of paragraphs. In DROP is a QA dataset which tests comprehensive understanding of paragraphs. In
this crowdsourced, adversarially-created, 96k question-answering benchmark, a this crowdsourced, adversarially-created, 96k question-answering benchmark, a
system must resolve multiple references in a question, map them onto a paragraph, system must resolve multiple references in a question, map them onto a paragraph,
and perform discrete operations over them (such as addition, counting, or sorting). and perform discrete operations over them (such as addition, counting, or sorting).
...@@ -24,7 +24,7 @@ from lm_eval.metrics import mean ...@@ -24,7 +24,7 @@ from lm_eval.metrics import mean
_CITATION = """ _CITATION = """
@misc{dua2019drop, @misc{dua2019drop,
title={DROP: A Reading Comprehension Benchmark Requiring Discrete Reasoning Over Paragraphs}, title={DROP: A Reading Comprehension Benchmark Requiring Discrete Reasoning Over Paragraphs},
author={Dheeru Dua and Yizhong Wang and Pradeep Dasigi and Gabriel Stanovsky and Sameer Singh and Matt Gardner}, author={Dheeru Dua and Yizhong Wang and Pradeep Dasigi and Gabriel Stanovsky and Sameer Singh and Matt Gardner},
year={2019}, year={2019},
eprint={1903.00161}, eprint={1903.00161},
...@@ -70,21 +70,26 @@ class DROP(Task): ...@@ -70,21 +70,26 @@ class DROP(Task):
@classmethod @classmethod
def get_answers(cls, qa): def get_answers(cls, qa):
def _flatten_validated_answers(validated_answers):
    """Flattens a dict of lists of validated answers.
    {"number": ['1', '8'], ...}
    -> [{"number": ['1'], ...}, {"number": ['8'], ...}]
    """
    valid_answers = []
    for i in range(len(validated_answers["number"])):
        valid_answers.append(
            {
                "number": validated_answers["number"][i],
                "date": validated_answers["date"][i],
                "spans": validated_answers["spans"][i],
            }
        )
    return valid_answers

answers = []
answers_set = set()
candidates = [qa["answer"]] + _flatten_validated_answers(
    qa["validated_answers"]
)
for candidate in candidates: for candidate in candidates:
answer = cls.parse_answer(candidate) answer = cls.parse_answer(candidate)
if answer in answers_set: if answer in answers_set:
...@@ -100,13 +105,21 @@ class DROP(Task): ...@@ -100,13 +105,21 @@ class DROP(Task):
return (str(answer["number"]),) return (str(answer["number"]),)
if answer["spans"] != []: if answer["spans"] != []:
return tuple(answer["spans"]) return tuple(answer["spans"])
return (" ".join([answer["date"]["day"], return (
answer["date"]["month"], " ".join(
answer["date"]["year"]]).strip(),) [answer["date"]["day"], answer["date"]["month"], answer["date"]["year"]]
).strip(),
)
def doc_to_text(self, doc): def doc_to_text(self, doc):
return f"Passage: {doc['passage']}\nQuestion: {doc['question']}\nAnswer:" return f"Passage: {doc['passage']}\nQuestion: {doc['question']}\nAnswer:"
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["passage"] + " " + doc["question"]
def doc_to_target(self, doc): def doc_to_target(self, doc):
return " " + ", ".join(doc["answers"][0]) return " " + ", ".join(doc["answers"][0])
...@@ -142,10 +155,7 @@ class DROP(Task): ...@@ -142,10 +155,7 @@ class DROP(Task):
if gold_answer[0].strip(): if gold_answer[0].strip():
max_em = max(max_em, exact_match) max_em = max(max_em, exact_match)
max_f1 = max(max_f1, f1_score) max_f1 = max(max_f1, f1_score)
return {"em": max_em, "f1": max_f1}
def get_metrics(self, predicted, gold): def get_metrics(self, predicted, gold):
""" """
...@@ -158,7 +168,9 @@ class DROP(Task): ...@@ -158,7 +168,9 @@ class DROP(Task):
predicted_bags = self._answer_to_bags(predicted) predicted_bags = self._answer_to_bags(predicted)
gold_bags = self._answer_to_bags(gold) gold_bags = self._answer_to_bags(gold)
if set(predicted_bags[0]) == set(gold_bags[0]) and len(
    predicted_bags[0]
) == len(gold_bags[0]):
exact_match = 1.0 exact_match = 1.0
else: else:
exact_match = 0.0 exact_match = 0.0
...@@ -190,7 +202,9 @@ class DROP(Task): ...@@ -190,7 +202,9 @@ class DROP(Task):
for gold_index, gold_item in enumerate(gold): for gold_index, gold_item in enumerate(gold):
for pred_index, pred_item in enumerate(predicted): for pred_index, pred_item in enumerate(predicted):
if self._match_numbers_if_present(gold_item, pred_item): if self._match_numbers_if_present(gold_item, pred_item):
scores[gold_index, pred_index] = self._compute_f1(
    pred_item, gold_item
)
row_ind, col_ind = linear_sum_assignment(-scores) row_ind, col_ind = linear_sum_assignment(-scores)
max_scores = np.zeros([max(len(gold), len(predicted))]) max_scores = np.zeros([max(len(gold), len(predicted))])
@@ -256,7 +270,11 @@ class DROP(Task):
    def _normalize(self, answer):
        tokens = [
            self._white_space_fix(
                self._remove_articles(
                    self._fix_number(self._remove_punc(token.lower()))
                )
            )
            for token in self._tokenize(answer)
        ]
        tokens = [token for token in tokens if token.strip()]
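The chained helpers here (_remove_punc, _fix_number, _remove_articles, _white_space_fix, _tokenize) are defined further down in the class and are not shown in this hunk. A simplified standalone sketch of the same normalization order (not the exact implementations) would be:

import re
import string

def _tokenize(text):
    return re.split(" |-", text)

def _remove_punc(text):
    return "".join(ch for ch in text if ch not in set(string.punctuation))

def _fix_number(text):
    # Simplified: strip commas so "1,000" and "1000" compare equal.
    return text.replace(",", "")

def _remove_articles(text):
    return re.sub(r"\b(a|an|the)\b", " ", text)

def _white_space_fix(text):
    return " ".join(text.split())

def normalize(answer):
    tokens = [
        _white_space_fix(_remove_articles(_fix_number(_remove_punc(token.lower()))))
        for token in _tokenize(answer)
    ]
    return " ".join(token for token in tokens if token.strip()).strip()

# e.g. normalize("The 1,000-year flood")  ->  "1000 year flood"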
@@ -269,10 +287,7 @@ class DROP(Task):
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metrics
        """
        return {"em": mean, "f1": mean}

    def higher_is_better(self):
        """
@@ -280,7 +295,4 @@ class DROP(Task):
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
        return {"em": True, "f1": True}

@@ -68,7 +68,15 @@ class CoLA(Task):
        return self.dataset["validation"]

    def doc_to_text(self, doc):
        return "{}\nQuestion: Does this sentence make sense?\nAnswer:".format(
            doc["sentence"]
        )

    def should_decontaminate(self):
        return True

    def doc_to_decontamination_query(self, doc):
        return doc["sentence"]

    def doc_to_target(self, doc):
        return " {}".format({1: "yes", 0: "no"}[doc["label"]])
@@ -82,19 +90,13 @@ class CoLA(Task):
        ll_true, ll_false = results
        pred = ll_true > ll_false
        gold = doc["label"]
        return {"mcc": (gold, pred)}

    def higher_is_better(self):
        return {"mcc": True}

    def aggregation(self):
        return {"mcc": matthews_corrcoef}
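process_results hands back raw (gold, pred) pairs and defers scoring to aggregation time, so the matthews_corrcoef referenced here is an aggregator over the whole list of pairs. A minimal sketch of such an aggregator, assuming scikit-learn is available (this may not match the harness's exact wrapper):

import sklearn.metrics

def matthews_corrcoef(items):
    # items: list of (gold, pred) pairs collected across all documents.
    golds, preds = zip(*items)
    return sklearn.metrics.matthews_corrcoef(golds, preds)

# e.g. matthews_corrcoef([(1, 1), (0, 0), (1, 0)])  ->  0.5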
class SST(Task):
@@ -136,19 +138,13 @@ class SST(Task):
        ll_positive, ll_negative = results
        pred = ll_positive > ll_negative
        gold = doc["label"]
        return {"acc": pred == gold}

    def higher_is_better(self):
        return {"acc": True}

    def aggregation(self):
        return {"acc": mean}


# Inference Tasks

@@ -184,7 +180,8 @@ class MNLI(Task):
    def doc_to_text(self, doc):
        return "{}\nQuestion: {} True, False or Neither?\nAnswer:".format(
            doc["premise"],
            doc["hypothesis"].strip()
            + ("" if doc["hypothesis"].strip().endswith(".") else "."),
        )
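The expression appended to the hypothesis only guarantees a trailing period before the question suffix. With an invented record:

doc = {"premise": "The cat sat on the mat.", "hypothesis": "A cat is resting "}  # invented record

hypothesis = doc["hypothesis"].strip()
hypothesis += "" if hypothesis.endswith(".") else "."

print("{}\nQuestion: {} True, False or Neither?\nAnswer:".format(doc["premise"], hypothesis))
# The cat sat on the mat.
# Question: A cat is resting. True, False or Neither?
# Answer: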
    def doc_to_target(self, doc):
@@ -202,19 +199,13 @@ class MNLI(Task):
    def process_results(self, doc, results):
        gold = doc["label"]
        pred = np.argmax(results)
        return {"acc": pred == gold}

    def higher_is_better(self):
        return {"acc": True}

    def aggregation(self):
        return {"acc": mean}
class MNLIMismatched(MNLI):
@@ -252,9 +243,11 @@ class QNLI(Task):
        return self.dataset["validation"]

    def doc_to_text(self, doc):
        return (
            "{}\n{}\nQuestion: Does this response answer the question?\nAnswer:".format(
                doc["question"],
                doc["sentence"],
            )
        )

    def doc_to_target(self, doc):
@@ -271,19 +264,13 @@ class QNLI(Task):
        ll_yes, ll_no = results
        pred = ll_no > ll_yes
        gold = doc["label"]
        return {"acc": pred == gold}

    def higher_is_better(self):
        return {"acc": True}

    def aggregation(self):
        return {"acc": mean}
class WNLI(Task):
@@ -328,19 +315,13 @@ class WNLI(Task):
        ll_true, ll_false = results
        pred = ll_true > ll_false
        gold = doc["label"]
        return {"acc": pred == gold}

    def higher_is_better(self):
        return {"acc": True}

    def aggregation(self):
        return {"acc": mean}


class RTE(Task):
@@ -385,19 +366,13 @@ class RTE(Task):
        ll_true, ll_false = results
        pred = ll_false > ll_true
        gold = doc["label"]
        return {"acc": pred == gold}

    def higher_is_better(self):
        return {"acc": True}

    def aggregation(self):
        return {"acc": mean}
# Similarity and Paraphrase Tasks

@@ -449,16 +424,10 @@ class MRPC(Task):
        }

    def higher_is_better(self):
        return {"acc": True, "f1": True}

    def aggregation(self):
        return {"acc": mean, "f1": f1_score}


class QQP(Task):
@@ -507,16 +476,10 @@ class QQP(Task):
        }

    def higher_is_better(self):
        return {"acc": True, "f1": True}

    def aggregation(self):
        return {"acc": mean, "f1": f1_score}
class STSB(Task):
@@ -554,22 +517,22 @@ class STSB(Task):
        return " {}".format(doc["label"])

    def construct_requests(self, doc, ctx):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        """
        # TODO: implement evaluation.
        raise NotImplementedError("Evaluation not implemented")
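STSB's construct_requests is left as a TODO because it is a regression task; the binary tasks in this module build their requests through the module-level RequestFactory, roughly as sketched below (shown for orientation, not as an STSB implementation):

def construct_requests(self, doc, ctx):
    # Ask the LM for the loglikelihood of each candidate continuation.
    ll_yes, _ = rf.loglikelihood(ctx, " yes")
    ll_no, _ = rf.loglikelihood(ctx, " no")
    return ll_yes, ll_no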
    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document

        :param doc:
@@ -578,22 +541,22 @@ class STSB(Task):
            The results of the requests created in construct_requests.
        """
        # TODO: implement evaluation.
        raise NotImplementedError("Evaluation not implemented")

    def aggregation(self):
        """
        :returns: {str: [float] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metrics
        """
        # TODO: implement evaluation.
        raise NotImplementedError("Evaluation not implemented")

    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
        # TODO: implement evaluation.
        raise NotImplementedError("Evaluation not implemented")