Commit baa8b0d3 authored by bzantium

fix for merge from master

parent a956bc63
@@ -64,8 +64,9 @@ class Unscramble(datasets.GeneratorBasedBuilder):
     VERSION = datasets.Version("0.0.1")
     BUILDER_CONFIGS = [
-        datasets.BuilderConfig(name=name, version=version,
-                               description=_DESCRIPTIONS[name])
+        datasets.BuilderConfig(
+            name=name, version=version, description=_DESCRIPTIONS[name]
+        )
         for name, version in zip(_NAMES, [VERSION] * len(_NAMES))
     ]
......
{"wikitext-103-v1": {"description": " The WikiText language modeling dataset is a collection of over 100 million tokens extracted from the set of verified\n Good and Featured articles on Wikipedia. The dataset is available under the Creative Commons Attribution-ShareAlike\n License.\n", "citation": "@misc{merity2016pointer,\n title={Pointer Sentinel Mixture Models},\n author={Stephen Merity and Caiming Xiong and James Bradbury and Richard Socher},\n year={2016},\n eprint={1609.07843},\n archivePrefix={arXiv},\n primaryClass={cs.CL}\n}\n", "homepage": "https://blog.einstein.ai/the-wikitext-long-term-dependency-language-modeling-dataset/", "license": "Creative Commons Attribution-ShareAlike 4.0 International (CC BY-SA 4.0)", "features": {"page": {"dtype": "string", "id": null, "_type": "Value"}}, "post_processed": null, "supervised_keys": null, "task_templates": null, "builder_name": "wikitext", "config_name": "wikitext-103-v1", "version": {"version_str": "1.0.0", "description": null, "major": 1, "minor": 0, "patch": 0}, "splits": {"test": {"name": "test", "num_bytes": 1281262, "num_examples": 62, "dataset_name": "wikitext"}, "train": {"name": "train", "num_bytes": 539297488, "num_examples": 29444, "dataset_name": "wikitext"}, "validation": {"name": "validation", "num_bytes": 1142488, "num_examples": 60, "dataset_name": "wikitext"}}, "download_checksums": {"https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip": {"num_bytes": 190229076, "checksum": "242ba0f20b329cfdf1ccc61e9e9e5b59becf189db7f7a81cd2a0e2fc31539590"}}, "download_size": 190229076, "post_processing_size": null, "dataset_size": 541721238, "size_in_bytes": 731950314}, "wikitext-2-v1": {"description": " The WikiText language modeling dataset is a collection of over 100 million tokens extracted from the set of verified\n Good and Featured articles on Wikipedia. The dataset is available under the Creative Commons Attribution-ShareAlike\n License.\n", "citation": "@misc{merity2016pointer,\n title={Pointer Sentinel Mixture Models},\n author={Stephen Merity and Caiming Xiong and James Bradbury and Richard Socher},\n year={2016},\n eprint={1609.07843},\n archivePrefix={arXiv},\n primaryClass={cs.CL}\n}\n", "homepage": "https://blog.einstein.ai/the-wikitext-long-term-dependency-language-modeling-dataset/", "license": "Creative Commons Attribution-ShareAlike 4.0 International (CC BY-SA 4.0)", "features": {"page": {"dtype": "string", "id": null, "_type": "Value"}}, "post_processed": null, "supervised_keys": null, "task_templates": null, "builder_name": "wikitext", "config_name": "wikitext-2-v1", "version": {"version_str": "1.0.0", "description": null, "major": 1, "minor": 0, "patch": 0}, "splits": {"test": {"name": "test", "num_bytes": 1256634, "num_examples": 62, "dataset_name": "wikitext"}, "train": {"name": "train", "num_bytes": 10799034, "num_examples": 629, "dataset_name": "wikitext"}, "validation": {"name": "validation", "num_bytes": 1121860, "num_examples": 60, "dataset_name": "wikitext"}}, "download_checksums": {"https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip": {"num_bytes": 4475746, "checksum": "92675f1d63015c1c8b51f1656a52d5bdbc33aafa60cc47a218a66e7ee817488c"}}, "download_size": 4475746, "post_processing_size": null, "dataset_size": 13177528, "size_in_bytes": 17653274}, "wikitext-103-raw-v1": {"description": " The WikiText language modeling dataset is a collection of over 100 million tokens extracted from the set of verified\n Good and Featured articles on Wikipedia. 
The dataset is available under the Creative Commons Attribution-ShareAlike\n License.\n", "citation": "@misc{merity2016pointer,\n title={Pointer Sentinel Mixture Models},\n author={Stephen Merity and Caiming Xiong and James Bradbury and Richard Socher},\n year={2016},\n eprint={1609.07843},\n archivePrefix={arXiv},\n primaryClass={cs.CL}\n}\n", "homepage": "https://blog.einstein.ai/the-wikitext-long-term-dependency-language-modeling-dataset/", "license": "Creative Commons Attribution-ShareAlike 4.0 International (CC BY-SA 4.0)", "features": {"page": {"dtype": "string", "id": null, "_type": "Value"}}, "post_processed": null, "supervised_keys": null, "task_templates": null, "builder_name": "wikitext", "config_name": "wikitext-103-raw-v1", "version": {"version_str": "1.0.0", "description": null, "major": 1, "minor": 0, "patch": 0}, "splits": {"test": {"name": "test", "num_bytes": 1290775, "num_examples": 62, "dataset_name": "wikitext"}, "train": {"name": "train", "num_bytes": 540656522, "num_examples": 29444, "dataset_name": "wikitext"}, "validation": {"name": "validation", "num_bytes": 1147025, "num_examples": 60, "dataset_name": "wikitext"}}, "download_checksums": {"https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-raw-v1.zip": {"num_bytes": 191984949, "checksum": "91c00ae287f0d699e18605c84afc9e45c192bc6b7797ff8837e5474655a33794"}}, "download_size": 191984949, "post_processing_size": null, "dataset_size": 543094322, "size_in_bytes": 735079271}, "wikitext-2-raw-v1": {"description": " The WikiText language modeling dataset is a collection of over 100 million tokens extracted from the set of verified\n Good and Featured articles on Wikipedia. The dataset is available under the Creative Commons Attribution-ShareAlike\n License.\n", "citation": "@misc{merity2016pointer,\n title={Pointer Sentinel Mixture Models},\n author={Stephen Merity and Caiming Xiong and James Bradbury and Richard Socher},\n year={2016},\n eprint={1609.07843},\n archivePrefix={arXiv},\n primaryClass={cs.CL}\n}\n", "homepage": "https://blog.einstein.ai/the-wikitext-long-term-dependency-language-modeling-dataset/", "license": "Creative Commons Attribution-ShareAlike 4.0 International (CC BY-SA 4.0)", "features": {"page": {"dtype": "string", "id": null, "_type": "Value"}}, "post_processed": null, "supervised_keys": null, "task_templates": null, "builder_name": "wikitext", "config_name": "wikitext-2-raw-v1", "version": {"version_str": "1.0.0", "description": null, "major": 1, "minor": 0, "patch": 0}, "splits": {"test": {"name": "test", "num_bytes": 1290775, "num_examples": 62, "dataset_name": "wikitext"}, "train": {"name": "train", "num_bytes": 10942633, "num_examples": 629, "dataset_name": "wikitext"}, "validation": {"name": "validation", "num_bytes": 1147025, "num_examples": 60, "dataset_name": "wikitext"}}, "download_checksums": {"https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip": {"num_bytes": 4721645, "checksum": "ef7edb566e3e2b2d31b29c1fdb0c89a4cc683597484c3dc2517919c615435a11"}}, "download_size": 4721645, "post_processing_size": null, "dataset_size": 13380433, "size_in_bytes": 18102078}}
\ No newline at end of file
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# NOTE: This is a modified version of https://github.com/huggingface/datasets/blob/master/datasets/wikitext/wikitext.py
# that returns Wiki pages instead of Wiki text line-by-line.
"""WikiText Dataset."""
import os
import datasets
_CITATION = """\
@misc{merity2016pointer,
title={Pointer Sentinel Mixture Models},
author={Stephen Merity and Caiming Xiong and James Bradbury and Richard Socher},
year={2016},
eprint={1609.07843},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
"""
_DESCRIPTION = """\
The WikiText language modeling dataset is a collection of over 100 million tokens extracted from the set of verified
Good and Featured articles on Wikipedia. The dataset is available under the Creative Commons Attribution-ShareAlike
License.
"""
_HOMEPAGE = "https://blog.einstein.ai/the-wikitext-long-term-dependency-language-modeling-dataset/"
_LICENSE = "Creative Commons Attribution-ShareAlike 4.0 International (CC BY-SA 4.0)"
_DATA_URL = "https://s3.amazonaws.com/research.metamind.io/wikitext"
class WikitextConfig(datasets.BuilderConfig):
"""BuilderConfig for GLUE."""
def __init__(self, data_url, **kwargs):
"""BuilderConfig for Wikitext
Args:
data_url: `string`, url to the dataset (word or raw level)
**kwargs: keyword arguments forwarded to super.
"""
super(WikitextConfig, self).__init__(
version=datasets.Version(
"1.0.0",
),
**kwargs,
)
self.data_url = data_url
class Wikitext(datasets.GeneratorBasedBuilder):
"""TODO(wikitext_103): Short description of my dataset."""
# TODO(wikitext_103): Set up version.
VERSION = datasets.Version("0.1.0")
BUILDER_CONFIGS = [
WikitextConfig(
name="wikitext-103-v1",
data_url=_DATA_URL + "/" + "wikitext-103-v1.zip",
description="Word level dataset. No processing is needed other than replacing newlines with <eos> tokens.",
),
WikitextConfig(
name="wikitext-2-v1",
data_url=_DATA_URL + "/" + "wikitext-2-v1.zip",
description="Word level dataset. No processing is needed other than replacing newlines with <eos> tokens.",
),
WikitextConfig(
name="wikitext-103-raw-v1",
data_url=_DATA_URL + "/" + "wikitext-103-raw-v1.zip",
description="Raw level dataset: the raw tokens before the addition of <unk> tokens. "
"They should only be used for character level work or for creating newly derived datasets.",
),
WikitextConfig(
name="wikitext-2-raw-v1",
data_url=_DATA_URL + "/" + "wikitext-2-raw-v1.zip",
description="Raw level dataset: the raw tokens before the addition of <unk> tokens. "
"They should only be used for character level work or for creating newly derived datasets.",
),
]
def _info(self):
# TODO(wikitext): Specifies the datasets.DatasetInfo object
return datasets.DatasetInfo(
# This is the description that will appear on the datasets page.
description=_DESCRIPTION,
# datasets.features.FeatureConnectors
features=datasets.Features(
{
"page": datasets.Value("string")
# These are the features of your dataset like images, labels ...
}
),
# If there's a common (input, target) tuple from the features,
# specify them here. They'll be used if as_supervised=True in
# builder.as_dataset.
supervised_keys=None,
homepage=_HOMEPAGE,
license=_LICENSE,
citation=_CITATION,
)
def _split_generators(self, dl_manager):
"""Returns SplitGenerators."""
# TODO(wikitext): Downloads the data and defines the splits
# dl_manager is a datasets.download.DownloadManager that can be used to
# download and extract URLs
if self.config.name == "wikitext-103-v1":
data_file = dl_manager.download_and_extract(self.config.data_url)
data_dir = os.path.join(data_file, "wikitext-103")
return [
datasets.SplitGenerator(
name=datasets.Split.TEST,
gen_kwargs={"data_file": os.path.join(
data_dir, "wiki.test.tokens"), "split": "test"},
),
datasets.SplitGenerator(
name=datasets.Split.TRAIN,
gen_kwargs={"data_file": os.path.join(
data_dir, "wiki.train.tokens"), "split": "train"},
),
datasets.SplitGenerator(
name=datasets.Split.VALIDATION,
gen_kwargs={"data_file": os.path.join(
data_dir, "wiki.valid.tokens"), "split": "valid"},
),
]
else:
if self.config.name == "wikitext-103-raw-v1":
data_file = dl_manager.download_and_extract(
self.config.data_url)
data_dir = os.path.join(data_file, "wikitext-103-raw")
return [
datasets.SplitGenerator(
name=datasets.Split.TEST,
gen_kwargs={"data_file": os.path.join(
data_dir, "wiki.test.raw"), "split": "test"},
),
datasets.SplitGenerator(
name=datasets.Split.TRAIN,
gen_kwargs={"data_file": os.path.join(
data_dir, "wiki.train.raw"), "split": "train"},
),
datasets.SplitGenerator(
name=datasets.Split.VALIDATION,
gen_kwargs={"data_file": os.path.join(
data_dir, "wiki.valid.raw"), "split": "valid"},
),
]
else:
if self.config.name == "wikitext-2-raw-v1":
data_file = dl_manager.download_and_extract(
self.config.data_url)
data_dir = os.path.join(data_file, "wikitext-2-raw")
return [
datasets.SplitGenerator(
name=datasets.Split.TEST,
gen_kwargs={"data_file": os.path.join(
data_dir, "wiki.test.raw"), "split": "test"},
),
datasets.SplitGenerator(
name=datasets.Split.TRAIN,
gen_kwargs={"data_file": os.path.join(
data_dir, "wiki.train.raw"), "split": "train"},
),
datasets.SplitGenerator(
name=datasets.Split.VALIDATION,
gen_kwargs={"data_file": os.path.join(
data_dir, "wiki.valid.raw"), "split": "valid"},
),
]
else:
if self.config.name == "wikitext-2-v1":
data_file = dl_manager.download_and_extract(
self.config.data_url)
data_dir = os.path.join(data_file, "wikitext-2")
return [
datasets.SplitGenerator(
name=datasets.Split.TEST,
gen_kwargs={"data_file": os.path.join(
data_dir, "wiki.test.tokens"), "split": "test"},
),
datasets.SplitGenerator(
name=datasets.Split.TRAIN,
gen_kwargs={
"data_file": os.path.join(data_dir, "wiki.train.tokens"),
"split": "train",
},
),
datasets.SplitGenerator(
name=datasets.Split.VALIDATION,
gen_kwargs={
"data_file": os.path.join(data_dir, "wiki.valid.tokens"),
"split": "valid",
},
),
]
def _generate_examples(self, data_file, split):
"""Yields examples."""
with open(data_file, encoding="utf-8") as f:
key = 0
ret = []
data = f.read().split("\n")
for line in data:
rline = line.replace("= = =", "===").replace("= =", "==").strip()
if rline.startswith('= ') and rline.strip().endswith(' ='):
page = '\n'.join(ret)
if page.strip():
yield key, {"page": page}
key += 1
ret = []
ret.append(line)
page = '\n'.join(ret)
yield key, {"page": page}
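Illustrative usage sketch (not part of the commit): the modified script can be loaded through the standard datasets API, assuming it is saved locally as wikitext.py; each example is then a whole Wiki page under the "page" key rather than a single line of text.

import datasets

# Hypothetical local path to the script above; any of the four config names
# defined in BUILDER_CONFIGS can be passed as the second argument.
wiki = datasets.load_dataset("./wikitext.py", "wikitext-2-raw-v1")
print(wiki["train"][0]["page"][:200])  # first 200 characters of the first page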
@@ -4,13 +4,18 @@ import json
 import jsonlines
 import io
 import datetime
+import mmap
+import tqdm
+from pathlib import Path


 def json_serial(obj):
     """JSON serializer for objects not serializable by default json code"""

     if isinstance(obj, (datetime.datetime,)):
         return obj.isoformat()
-    raise TypeError ("Type %s not serializable" % type(obj))
+    raise TypeError("Type %s not serializable" % type(obj))


 # Modified version of lm_dataformat Archive for single file.
 class Archive:
@@ -18,26 +23,32 @@ class Archive:
         self.file_path = file_path
         dir_name = os.path.dirname(file_path)
         if dir_name:
             os.makedirs(dir_name, exist_ok=True)
-        self.fh = open(self.file_path, 'wb')
+        self.fh = open(self.file_path, "wb")
         self.cctx = zstandard.ZstdCompressor(level=compression_level)
         self.compressor = self.cctx.stream_writer(self.fh)

     def add_data(self, data, meta={}):
-        self.compressor.write(json.dumps({'text': data, 'meta': meta}, default=json_serial).encode('UTF-8') + b'\n')
+        self.compressor.write(
+            json.dumps({"text": data, "meta": meta}, default=json_serial).encode(
+                "UTF-8"
+            )
+            + b"\n"
+        )

     def commit(self):
         self.compressor.flush(zstandard.FLUSH_FRAME)
         self.fh.flush()
         self.fh.close()


 # Modified version of lm_dataformat Reader with self.fh set, allowing peeking for tqdm.
 class Reader:
     def __init__(self):
         pass

-    def read(self, file, get_meta=False, autojoin_paragraphs=True, para_joiner='\n\n'):
-        with open(file, 'rb') as fh:
+    def read(self, file, get_meta=False, autojoin_paragraphs=True, para_joiner="\n\n"):
+        with open(file, "rb") as fh:
             self.fh = fh
             cctx = zstandard.ZstdDecompressor()
             reader = io.BufferedReader(cctx.stream_reader(fh))
@@ -49,42 +60,102 @@ class Reader:
                     yield ob
                     continue

-                text = ob['text']
+                text = ob["text"]
                 if autojoin_paragraphs and isinstance(text, list):
                     text = para_joiner.join(text)
                 if get_meta:
-                    yield text, (ob['meta'] if 'meta' in ob else {})
+                    yield text, (ob["meta"] if "meta" in ob else {})
                 else:
                     yield text


+# Simple text reader and writer with same interface as above
 class TextArchive:
-    def __init__(self, file_path, mode="ab"):
+    def __init__(self, file_path, mode="rb+"):
         self.file_path = file_path
         dir_name = os.path.dirname(file_path)
         if dir_name:
             os.makedirs(dir_name, exist_ok=True)
-        self.fh = open(self.file_path, mode)

-    def add_data(self, data, meta={}):
-        self.fh.write(data.encode('UTF-8') + b'\n')
+        if not os.path.exists(file_path):
+            Path(file_path).touch()
+
+        self.fh = open(self.file_path, mode)
+
+    def add_data(self, data):
+        self.fh.write(data.encode("UTF-8") + b"\n")

     def commit(self):
         self.fh.flush()
         self.fh.close()


 class TextReader:
     def __init__(self, file_path):
         self.file_path = file_path

+    # Optimized mmap read with infrequent tqdm updates to maintain speed
+    # Tested up to 250MB/s.
+    def read_tqdm(self, update_frequency=10000):
+        current_file_position = 0
+        line_counter = 0
+        with open(self.file_path, "r") as fh, tqdm.tqdm(
+            total=os.path.getsize(self.file_path),
+            dynamic_ncols=True,
+            unit="byte",
+            unit_scale=1,
+        ) as progress:
+            with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
+                for line in iter(mmap_obj.readline, b""):
+                    line = line.decode("utf-8")
+                    line_counter += 1
+                    if line_counter == update_frequency:
+                        new_file_pos = mmap_obj.tell()
+                        bytes_read = new_file_pos - current_file_position
+                        current_file_position = new_file_pos
+                        progress.update(bytes_read)
+                        line_counter = 0
+                    yield line[:-1]
+
+    def read_and_tell(self):
+        current_file_position = 0
+        with open(self.file_path, "r", encoding="utf8") as fh:
+            with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
+                for line in iter(mmap_obj.readline, b""):
+                    line = line.decode("utf-8")
+                    new_file_pos = mmap_obj.tell()
+                    raw_bytes_read = new_file_pos - current_file_position
+                    current_file_position = new_file_pos
+                    yield line[:-1], raw_bytes_read
+
     def read(self):
-        with open(self.file_path, 'r', encoding="utf8") as fh:
-            self.fh = fh
+        with open(self.file_path, "r", encoding="utf8") as fh:
+            with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
+                for line in iter(mmap_obj.readline, b""):
+                    line = line.decode("utf-8")
+                    yield line[:-1]
+
+    def read_slow(self):
+        with open(self.file_path, "r", encoding="utf8") as fh:
             while True:
-                line = self.fh.readline()
+                line = fh.readline()
                 if line == -1 or line == "":
                     break
-                else :
+                else:
                     yield line[:-1]
\ No newline at end of file
+
+
+# Optimized for speed. Decompresses the archive in shell before
+# using the mmap'd TextReader.
+class ZStdTextReader:
+    def __init__(self, file):
+        self.file = file
+
+    def read_tqdm(self):
+        decompressed_file = self.file[:-4]
+        print("Decompressing file, please wait...")
+        os.system(f"zstd -d {self.file}")  # linux decompress is faster
+        reader = TextReader(decompressed_file)
+        yield from reader.read_tqdm()
+        os.remove(decompressed_file)
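Illustrative round-trip sketch for the archive utilities above (not part of the commit; the import path and file name are assumptions): Archive writes zstandard-compressed JSONL, Reader streams it back.

from archiver import Archive, Reader  # assumed module name for the file above

# Write two documents with metadata into a compressed .jsonl.zst archive.
archive = Archive("data/example.jsonl.zst")
archive.add_data("first document text", meta={"source": "demo"})
archive.add_data("second document text", meta={"source": "demo"})
archive.commit()  # flush the zstd frame and close the file handle

# Stream the documents back; get_meta=True yields (text, meta) pairs.
for text, meta in Reader().read("data/example.jsonl.zst", get_meta=True):
    print(meta["source"], text[:20])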
import time
import random
import pickle
import json
import glob
import os
import collections
from .janitor import Janitor, word_ngrams
from .archiver import ZStdTextReader
# Was used for testing the evaluator decoupled from the full logic below
def get_train_overlap_stub(docs, ngrams_path, ngrams_n_size):
simulated_overlap = 0.1
contaminated = int(len(docs) * simulated_overlap)
return random.sample(range(len(docs)), contaminated)
# Returns a dictionary containing all overlapping documents in each
# task. In the standard use case, an overlap occurs when any of the 13-grams
# found in the task document exist in the training set documents.
#
# To generate 13-grams for the pile see scripts/clean_training_data. The final output of these
# scripts are an info.json file containing the n_gram_size (13) and a bunch of "ngrams_{x}.bkt.txt.sorted.zst"
# files. These should exist in the "ngrams_path" provided to this function.
# Algorithm:
# 1. Build lookups for each dataset {ngram: list(document_ids)}
# 2. Merge into an overall lookup {ngram: [(task_name, task_set, doc_ids),]}
# 3. Full scan the 13-grams from the training set against the merged lookup,
# saving matches in the "duplicates" dictionary {(task_name, task_set): set(doc_ids)}
# 4. Strip the task_set from the dictionary keys and return
#
# We cache the task+set lookups as well as the overlaps.
def get_train_overlap(docs_by_task_set, ngrams_path, limit):
# return get_train_overlap_stub(docs, ngrams_path, ngrams_n_size)
info_dict_path = os.path.join(ngrams_path, "info.json")
info_dict = json.load(open(info_dict_path, "r"))
ngrams_n_size = info_dict["ngram_size"]
janitor = Janitor()
# Build lookup for each dataset first in case we use different task combinations later
print("Building Lookups...")
start = time.perf_counter()
def get_overlaps_dump_path(task_name, task_set, ngrams_n_size, limit):
return f"data/{task_name}/{task_set}_{ngrams_n_size}grams_limit{limit}.overlaps"
lookups = {}
duplicates = {} # (task_name, task_set): set(doc_ids)}
sets_to_decontaminate = len(docs_by_task_set.keys())
for (task_name, task_set), docs in docs_by_task_set.items():
if not os.path.exists(f"data/{task_name}"):
os.mkdir(f"data/{task_name}")
# Check if we've decontaminated this combination before
overlaps_dump_path = get_overlaps_dump_path(
task_name, task_set, ngrams_n_size, limit
)
if os.path.exists(overlaps_dump_path):
duplicates[(task_name, task_set)] = pickle.load(
open(overlaps_dump_path, "rb")
)
sets_to_decontaminate -= 1
continue
else:
duplicates[(task_name, task_set)] = set()
# Build/load the task lookup {ngram: set(documents)}.
task_set_lookup_path = (
f"data/{task_name}/{task_set}_{ngrams_n_size}grams_limit{limit}.lookup"
)
if os.path.exists(task_set_lookup_path):
print(f"{task_set_lookup_path} available, loading...")
lookups[(task_name, task_set)] = pickle.load(
open(task_set_lookup_path, "rb")
)
else:
print(f"{task_set_lookup_path} not available, building...")
lookup = collections.defaultdict(set)
for doc_id, document in enumerate(docs):
ngrams = word_ngrams(janitor.normalize_string(document), ngrams_n_size)
for ngram in ngrams:
lookup[ngram].add(doc_id)
pickle.dump(lookup, open(task_set_lookup_path, "wb"))
lookups[(task_name, task_set)] = lookup
elapsed = time.perf_counter() - start
print(f"Building lookups took {elapsed:0.5f} seconds.")
matched_ngrams = []
if sets_to_decontaminate > 0:
print("Merging lookups...")
start = time.perf_counter()
merged_lookup = collections.defaultdict(list)
for (task_name, task_set), lookup in lookups.items():
for ngram, doc_ids in lookup.items():
merged_lookup[ngram].append((task_name, task_set, doc_ids))
elapsed = time.perf_counter() - start
print(f"Merging lookups took {elapsed:0.5f} seconds.")
print(f"{ngrams_n_size} grams files found in {ngrams_path}:")
files = glob.glob(os.path.join(ngrams_path, f"*.sorted.zst"))
print(files)
for file in files:
start = time.perf_counter()
print(f"Scanning {file}")
reader = ZStdTextReader(file)
total_ngrams = 0
unique_ngrams = 0
matching_unique = 0
non_matching_unique = 0
current_ngram = ""
for line in reader.read_tqdm(): # Scan training set ngrams file
total_ngrams += 1
[ngram, document_id] = line.rsplit(" ", 1)
if (
ngram != current_ngram
): # Only need to match the ngram once in training set
unique_ngrams += 1
current_ngram = ngram
if ngram in merged_lookup:
matched_ngrams.append(ngram) # For logging
matching_unique += 1
for task_name, task_set, doc_ids in merged_lookup[ngram]:
task_doc_set = duplicates[(task_name, task_set)]
for (
doc_id
) in (
doc_ids
): # Record contamination across all relevant task/set combos
task_doc_set.add(doc_id)
del merged_lookup[ngram] # No point matching again
else:
non_matching_unique += 1
print(f"Total Ngrams: {total_ngrams}")
print(f"Unique Ngrams: {unique_ngrams}")
print(f"Unique Matching: {matching_unique}")
print(f"Unique Non Matching: {non_matching_unique}")
print("Matched ngrams:")
for ngram in matched_ngrams:
print(ngram)
elapsed = time.perf_counter() - start
print(f"Read took {elapsed:0.5f} seconds.")
print(f"Speed: {(os.path.getsize(file)/1000000.0)/elapsed}MB/second")
print(duplicates)
# Dump overlaps separately
for (task_name, task_set), doc_ids in duplicates.items():
overlaps_dump_path = get_overlaps_dump_path(
task_name, task_set, ngrams_n_size, limit
)
pickle.dump(doc_ids, open(overlaps_dump_path, "wb"))
# Strip task set and return
return {task_name: doc_ids for (task_name, task_set), doc_ids in duplicates.items()}
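Illustrative call sketch for get_train_overlap (not part of the commit; paths and task names are hypothetical): the ngrams directory must contain info.json and the sorted .zst n-gram files produced by scripts/clean_training_data, and a local data/ directory is used to cache lookups and overlaps.

from lm_eval.decontamination.decontaminate import get_train_overlap

# Keys are (task_name, task_set) pairs; values are the documents to check.
docs_by_task_set = {
    ("lambada", "test"): ["first eval document ...", "second eval document ..."],
    ("hellaswag", "val"): ["another eval document ..."],
}

# Returns {task_name: set(doc_ids)} of documents sharing an n-gram with the training set.
overlaps = get_train_overlap(docs_by_task_set, "path/to/pile_13grams", limit=None)
print(overlaps.get("lambada", set()))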
@@ -9,8 +9,9 @@ from pprint import pprint
 # c++ -O3 -Wall -shared -std=c++11 -fPIC $(python3 -m pybind11 --includes) janitor_util.cpp -o janitor_util$(python3-config --extension-suffix) -undefined dynamic_lookup
 try:
     import janitor_util
+
     JANITOR_CPP = True
-except Exception as e:
+except Exception:
     print("WARNING: C++ module could not be loaded. Janitor running in python mode")
     traceback.print_exc()
     JANITOR_CPP = False
@@ -41,6 +42,7 @@ def word_ngrams(s, n):
     ngram_seqs = form_ngrams(iter(tokens), n)
     return (" ".join(ngram) for ngram in ngram_seqs)

+
 # Does character sequences only - combined faster function to play around with later
 # def word_ngrams_indices_combined(sequence, n):
 #     current_word = ""
@@ -70,7 +72,7 @@ def split_indices(s):
     """Splits a string on whitespaces and records the indices of each in the original string.
     @:return generator((word, (start_idx, end_idx)), ...)
     """
-    return ((m.group(0), (m.start(), m.end() - 1)) for m in re.finditer(r'\S+', s))
+    return ((m.group(0), (m.start(), m.end() - 1)) for m in re.finditer(r"\S+", s))


 def word_ngrams_indices(s, n):
@@ -90,22 +92,27 @@ def word_ngrams_indices(s, n):
     #   ([word, word, ...], [(start,end), (start,end), ...]),
     #   ...
     # )
-    ngram_indices_pairs = (zip(*ngram_with_indices) for ngram_with_indices in ngram_seqs_with_indices)
+    ngram_indices_pairs = (
+        zip(*ngram_with_indices) for ngram_with_indices in ngram_seqs_with_indices
+    )

     # Generator of ( (word_ngram, (start, end)), (word_ngram, start, end)), ...)
-    return ((" ".join(ngram_seq), (indices[0][0], indices[-1][1])) for ngram_seq, indices in ngram_indices_pairs)
+    return (
+        (" ".join(ngram_seq), (indices[0][0], indices[-1][1]))
+        for ngram_seq, indices in ngram_indices_pairs
+    )


 class Janitor:

     # FIXME delete_chars: Should anything else go here? Special chars?
     def __init__(
         self,
         ngram_n=13,
         window_to_remove=200,
         too_dirty_cutoff=10,
         minimum_slice_length=200,
-        delete_chars=string.punctuation
+        delete_chars=string.punctuation,
     ):
         self.ngram_n = ngram_n
         self.window_to_remove = window_to_remove
@@ -121,7 +128,7 @@ class Janitor:
         self.translation_table = str.maketrans(
             string.ascii_lowercase + string.ascii_uppercase,  # These characters
             string.ascii_lowercase * 2,  # Become these characters
-            self.delete_chars  # These are deleted
+            self.delete_chars,  # These are deleted
         )

     ##############
@@ -129,14 +136,13 @@ class Janitor:
     ##############

     def save_contamination_ngrams(self, filename):
-        with open(filename, 'wb') as fp:
+        with open(filename, "wb") as fp:
             pickle.dump(filename, fp)

     def load_contamination_ngrams(self, filename):
-        with open(filename, 'rb') as fp:
+        with open(filename, "rb") as fp:
             self.dirt_ngrams = pickle.load(fp)

     ##############
     # Call these :)
     ##############
@@ -152,7 +158,7 @@ class Janitor:
     def clean(self, dirty_string):
         """Clean a string (e.g. a training set) by removing all ngrams previously
-        reigstered as contaminants. Returns a list of clean chunks, or empty if
+        registered as contaminants. Returns a list of clean chunks, or empty if
         the string was too dirty"""
         if JANITOR_CPP:
             return self.clean_cpp(dirty_string)
@@ -171,11 +177,11 @@ class Janitor:
             end = min(len(dirty_string), end + self.window_to_remove)

             if start - splice_idx > self.minimum_slice_length:
-                clean_chunks.append(dirty_string[splice_idx: start])
+                clean_chunks.append(dirty_string[splice_idx:start])
             splice_idx = end

         if end < len(dirty_string) - self.minimum_slice_length:
-            clean_chunks.append(dirty_string[end+1:])
+            clean_chunks.append(dirty_string[end + 1 :])

         return clean_chunks
@@ -184,10 +190,14 @@ class Janitor:
     ##############

     def register_contaminant_cpp(self, dirt_string):
-        self.dirt_ngrams.update(janitor_util.clean_ngram(dirt_string, self.delete_chars, self.ngram_n))
+        self.dirt_ngrams.update(
+            janitor_util.clean_ngram(dirt_string, self.delete_chars, self.ngram_n)
+        )

     def clean_cpp(self, dirty_string):
-        contamination_indices = janitor_util.clean_ngram_with_indices(dirty_string, self.delete_chars, self.ngram_n)
+        contamination_indices = janitor_util.clean_ngram_with_indices(
+            dirty_string, self.delete_chars, self.ngram_n
+        )
         return self._split_chunks(dirty_string, contamination_indices)

     ##############
@@ -198,7 +208,9 @@ class Janitor:
         return s.translate(self.translation_table)

     def register_contaminant_python(self, dirt_string):
-        self.dirt_ngrams.update(word_ngrams(self.normalize_string(dirt_string), self.ngram_n))
+        self.dirt_ngrams.update(
+            word_ngrams(self.normalize_string(dirt_string), self.ngram_n)
+        )

     def clean_python(self, dirty_string):
         contamination_indices = (
@@ -249,29 +261,29 @@ class Janitor:
 #     data = f.read()
 #     jan = Janitor(too_dirty_cutoff=1000)
 #     jan.register_contaminant('''
 #    theories is that there is a connection between &quot;geekdom&quot; and autism.
 #    This is hinted, for instance, by a ''Wired Magazine'' article in 2001 entitled &quot;
 #    The [[Geek]] Syndrome&quot;, which is a point argued by many in the autism rights
 #    movement{{ref|Wired}}. This article, many professionals assert, is just one example of
 #    the media's application of mental disease labels to what is actually variant normal behavior
 #    &amp;mdash;they argue that shyness, lack of athletic ability or social skills, and intellectual
 #    interests, even when they seem unusual to others, are not in themselves signs of autism or
 #    Asperger's syndrome. Others assert that it is actually the medical profession which is applying
 #    mental disease labels to children who in the past would have simply been accepted as a little
 #    different or even labeled 'gifted'. See [[clinomorphism]] for further discussion of this issue.
 #    Due to the recent publicity surrounding autism and autis
 #    ultan Al Nahyan]] granted [[Petroleum]] concessions, and oil was first found in 1958. At first,
 #    oil money had a marginal impact. A few lowrise concete buildings were erected, and the first
 #    paved road was completed in 1961, but Sheikh Shakbut, uncertain whether the new oil royalties
-#    would last, took a cautious approach, prefering to save the revenue rather than investing it in
+#    would last, took a cautious approach, preferring to save the revenue rather than investing it in
 #    development. His brother, [[Zayed bin Sultan Al Nahayan]], saw that oil wealth had the potential
 #    to transform Abu Dhabi. The ruling Al Nahayan family decided that Sheikh Zayed should replace his
 #    brother as Ruler and carry out his vision of developing the country. On [[August 6]], [[1966]],
 #    with the assistance of the British, Sheikh Zayed became the new ruler. See generally, Al-Fahim, M,
 #    ''From Rags to Riches: A Story of Abu Dhabi'', Chapter Six (London Centre of Arab Studies, 1995),
 #    ISBN 1 900404 00 1. With the announcement by Britain in 1968 that it would withdraw from the
 #    Gulf area by 1971, Sheikh Zayed became the main driving force behind the formation of the
 #    [[United Arab Emirates]]. After the Emirates gained independence in 1971,
 #    ''')
 # """
......
 import collections
 import itertools
-import pathlib
+import numpy as np
 import random

 import lm_eval.metrics
 import lm_eval.models
 import lm_eval.tasks
 import lm_eval.base
-import numpy as np

 from lm_eval.utils import positional_deprecated, run_task_tests


 @positional_deprecated
-def simple_evaluate(model, model_args=None, tasks=[],
-                    num_fewshot=0, batch_size=None, device=None,
-                    no_cache=False, limit=None, bootstrap_iters=100000,
-                    description_dict=None, check_integrity=False):
+def simple_evaluate(
+    model,
+    model_args=None,
+    tasks=[],
+    num_fewshot=0,
+    batch_size=None,
+    device=None,
+    no_cache=False,
+    limit=None,
+    bootstrap_iters=100000,
+    description_dict=None,
+    check_integrity=False,
+    decontamination_ngrams_path=None,
+):
     """Instantiate and evaluate a model on a list of tasks.

     :param model: Union[str, LM]
         Name of model or LM object, see lm_eval.models.get_model
     :param model_args: Optional[str]
         String arguments for each model class, see LM.create_from_arg_string.
         Ignored if `model` argument is a LM object.
     :param tasks: list[Union[str, Task]]
         List of task names or Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
@@ -37,7 +47,7 @@ def simple_evaluate(model, model_args=None, tasks=[],
     :param bootstrap_iters:
         Number of iterations for bootstrap statistics
     :param description_dict: dict[str, str]
         Dictionary of custom task descriptions of the form: `task_name: description`
     :param check_integrity: bool
         Whether to run the relevant part of the test suite for the tasks
     :return
@@ -49,19 +59,25 @@ def simple_evaluate(model, model_args=None, tasks=[],
     assert tasks != [], "No tasks specified"

     if isinstance(model, str):
-        if model_args is None: model_args = ""
-        lm = lm_eval.models.get_model(model).create_from_arg_string(model_args, {
-            'batch_size': batch_size, 'device': device
-        })
+        if model_args is None:
+            model_args = ""
+        lm = lm_eval.models.get_model(model).create_from_arg_string(
+            model_args, {"batch_size": batch_size, "device": device}
+        )
     else:
         assert isinstance(model, lm_eval.base.LM)
         lm = model

     if not no_cache:
         lm = lm_eval.base.CachingLM(
-            lm, 'lm_cache/' + model + '_' + model_args.replace('=', '-').replace(',', '_').replace('/', '-') + '.db'
+            lm,
+            "lm_cache/"
+            + model
+            + "_"
+            + model_args.replace("=", "-").replace(",", "_").replace("/", "-")
+            + ".db",
         )

     task_dict = lm_eval.tasks.get_task_dict(tasks)

     if check_integrity:
@@ -72,7 +88,9 @@ def simple_evaluate(model, model_args=None, tasks=[],
         task_dict=task_dict,
         num_fewshot=num_fewshot,
         limit=limit,
-        description_dict=description_dict
+        bootstrap_iters=bootstrap_iters,
+        description_dict=description_dict,
+        decontamination_ngrams_path=decontamination_ngrams_path,
     )

     # add info about the model and few shot config
@@ -85,14 +103,26 @@ def simple_evaluate(model, model_args=None, tasks=[],
         "no_cache": no_cache,
         "limit": limit,
         "bootstrap_iters": bootstrap_iters,
-        "description_dict": description_dict
+        "description_dict": description_dict,
     }

     return results


+decontaminate_suffix = "_decontaminate"
+
+
 @positional_deprecated
-def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None, bootstrap_iters=100000, description_dict=None):
+def evaluate(
+    lm,
+    task_dict,
+    provide_description=None,
+    num_fewshot=0,
+    limit=None,
+    bootstrap_iters=100000,
+    description_dict=None,
+    decontamination_ngrams_path=None,
+):
     """Instantiate and evaluate a model on a list of tasks.

     :param lm: obj
@@ -108,7 +138,7 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
     :param bootstrap_iters:
         Number of iterations for bootstrap statistics
     :param description_dict: dict[str, str]
         Dictionary of custom task descriptions of the form: `task_name: description`
     :return
         Dictionary of results
     """
@@ -118,12 +148,16 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
     assert not provide_description  # not implemented.
     if provide_description is not None:
         # nudge people to not specify it at all
-        print("WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict")
+        print(
+            "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
+        )
+
+    decontaminate = decontamination_ngrams_path is not None

     task_dict_items = [
         (name, task)
         for name, task in task_dict.items()
-        if(task.has_validation_docs() or task.has_test_docs())
+        if (task.has_validation_docs() or task.has_test_docs())
     ]

     results = collections.defaultdict(dict)
@@ -132,6 +166,8 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
     requests = collections.defaultdict(list)
     requests_origin = collections.defaultdict(list)

+    overlaps = collections.defaultdict(list)  # {task_name: contaminated_docs}
+
     # If we ever run into issues where the eval tasks don't fit in memory and we can't afford a machine with bigger
     # memory, we can always modify this plumbing to support that, but I didn't want to include it just yet because
     # over-engineering is bad (or we could make it write the requests to disk and then read them back out again
@@ -140,6 +176,8 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
     # TODO: we need unit tests & sanity checks or something to ensure that the return of `validation_docs` is stable
     docs = {}

+    docs_for_decontamination = collections.defaultdict(list)
+
     # get lists of each type of request
     for task_name, task in task_dict_items:
         versions[task_name] = task.VERSION
@@ -147,7 +185,9 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
         # TODO: the test-fallback-to-val system isn't final, we should revisit it at some point
         if task.has_test_docs():
             task_doc_func = task.test_docs
+            task_set = "test"  # Required for caching in the decontamination
         elif task.has_validation_docs():
+            task_set = "val"  # Required for caching in the decontamination
             task_doc_func = task.validation_docs
         else:
             raise RuntimeError("Task has neither test_docs nor validation_docs")
@@ -158,15 +198,22 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
         rnd.seed(42)
         rnd.shuffle(task_docs)

-        description = description_dict[task_name] if description_dict and task_name in description_dict else ""
+        description = (
+            description_dict[task_name]
+            if description_dict and task_name in description_dict
+            else ""
+        )

         for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):
+            if decontaminate and task.should_decontaminate():
+                docs_for_decontamination[(task_name, task_set)].append(
+                    task.doc_to_decontamination_query(doc)
+                )
+
             docs[(task_name, doc_id)] = doc
             ctx = task.fewshot_context(
-                doc=doc,
-                num_fewshot=num_fewshot,
-                rnd=rnd,
-                description=description
+                doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
             )
             reqs = task.construct_requests(doc, ctx)
             if not isinstance(reqs, (list, tuple)):
@@ -177,6 +224,15 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
                 # doc_id: unique id that we can get back to a doc using `docs`
                 requests_origin[req.request_type].append((i, task_name, doc, doc_id))

+    # Compare all tasks/sets at once to ensure a single training set scan
+    if decontaminate:
+        from lm_eval.decontamination.decontaminate import get_train_overlap
+
+        print("Finding train/test overlap, please wait...")
+        overlaps = get_train_overlap(
+            docs_for_decontamination, decontamination_ngrams_path, limit
+        )
+
     # all responses for each (task, doc)
     process_res_queue = collections.defaultdict(list)
@@ -189,11 +245,13 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
         print("Running", reqtype, "requests")
         resps = getattr(lm, reqtype)([req.args for req in reqs])
-        resps = [x if req.index is None else x[req.index] for x, req in zip(resps, reqs)]
+        resps = [
+            x if req.index is None else x[req.index] for x, req in zip(resps, reqs)
+        ]

         for resp, (i, task_name, doc, doc_id) in zip(resps, requests_origin[reqtype]):
             process_res_queue[(task_name, doc_id)].append((i, resp))

     vals = collections.defaultdict(list)

     # unpack results and sort back in order and return control to Task
@@ -207,25 +265,36 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
         metrics = task.process_results(doc, requests)
         for metric, value in metrics.items():
             vals[(task_name, metric)].append(value)

+            # Re-use the evaluation for the decontaminated set by just ignoring the overlaps
+            if decontaminate and task_name in overlaps:
+                if doc_id not in overlaps[task_name]:
+                    vals[(task_name, metric + decontaminate_suffix)].append(value)
+
     # aggregate results
     for (task_name, metric), items in vals.items():
         task = task_dict[task_name]
-        results[task_name][metric] = task.aggregation()[metric](items)
+        real_metric = metric  # key when looking up the metric with task.aggregation
+        if metric.endswith(decontaminate_suffix):
+            real_metric = metric.replace(
+                decontaminate_suffix, ""
+            )  # decontaminated still uses the same metric
+        results[task_name][metric] = task.aggregation()[real_metric](items)

         # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
         # so we run them less iterations. still looking for a cleaner way to do this
         stderr = lm_eval.metrics.stderr_for_metric(
-            metric=task.aggregation()[metric],
-            bootstrap_iters=min(bootstrap_iters, 1000) if metric in ["bleu", "chrf", "ter"] else bootstrap_iters,
+            metric=task.aggregation()[real_metric],
+            bootstrap_iters=min(bootstrap_iters, 1000)
+            if metric in ["bleu", "chrf", "ter"]
+            else bootstrap_iters,
         )
+
         if stderr is not None:
             results[task_name][metric + "_stderr"] = stderr(items)

-    return {
-        "results": dict(results),
-        "versions": dict(versions)
-    }
+    return {"results": dict(results), "versions": dict(versions)}


 def make_table(result_dict):
@@ -247,9 +316,9 @@ def make_table(result_dict):
             if m + "_stderr" in dic:
                 se = dic[m + "_stderr"]
-                values.append([k, version, m, '%.4f' % v, '±', '%.4f' % se])
+                values.append([k, version, m, "%.4f" % v, "±", "%.4f" % se])
             else:
-                values.append([k, version, m, '%.4f' % v, '', ''])
+                values.append([k, version, m, "%.4f" % v, "", ""])
             k = ""
             version = ""
     md_writer.value_matrix = values
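Illustrative usage sketch for the new decontamination plumbing (not part of the commit; paths are hypothetical): passing decontamination_ngrams_path makes the results include an extra "<metric>_decontaminate" entry per task, computed with contaminated documents ignored.

from lm_eval.evaluator import simple_evaluate  # assumed import path

results = simple_evaluate(
    model="gpt2",
    tasks=["lambada"],
    num_fewshot=0,
    batch_size=8,
    # Directory holding info.json and the sorted 13-gram .zst files.
    decontamination_ngrams_path="path/to/pile_13grams",
)
print(results["results"]["lambada"])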
......
@@ -110,6 +110,7 @@ def weighted_mean(items):
 def weighted_perplexity(items):
     return math.exp(-weighted_mean(items))

+
 def bits_per_byte(items):
     return -weighted_mean(items) / math.log(2)
@@ -191,8 +192,10 @@ def _sacreformat(refs, preds):
     return refs, preds

+
 # stderr stuff

+
 class _bootstrap_internal:
     def __init__(self, f, n):
         self.f = f
@@ -210,9 +213,10 @@ class _bootstrap_internal:
 def bootstrap_stderr(f, xs, iters):
     import multiprocessing as mp
+
     pool = mp.Pool(mp.cpu_count())

     # this gives a biased estimate of the stderr (i.e w/ the mean, it gives something
     # equivalent to stderr calculated without Bessel's correction in the stddev.
     # Unfortunately, I haven't been able to figure out what the right correction is
     # to make the bootstrap unbiased - i considered multiplying by sqrt(n/(n-1)) but
     # that would be ad-hoc and I can't prove that that would actually be an unbiased estimator)
@@ -220,10 +224,15 @@ def bootstrap_stderr(f, xs, iters):
     res = []
     chunk_size = min(1000, iters)

     from tqdm import tqdm

     print("bootstrapping for stddev:", f.__name__)
-    for bootstrap in tqdm(pool.imap(
-            _bootstrap_internal(f, chunk_size),
-            [(i, xs) for i in range(iters // chunk_size)]), total=iters // chunk_size):
+    for bootstrap in tqdm(
+        pool.imap(
+            _bootstrap_internal(f, chunk_size),
+            [(i, xs) for i in range(iters // chunk_size)],
+        ),
+        total=iters // chunk_size,
+    ):
         # sample w replacement
         res.extend(bootstrap)
@@ -246,17 +255,13 @@ def stderr_for_metric(metric, bootstrap_iters):
     if metric in bootstrappable:
         return lambda x: bootstrap_stderr(metric, x, iters=bootstrap_iters)

-    stderr = {
-        mean: mean_stderr,
-        acc_all: acc_all_stderr
-    }
+    stderr = {mean: mean_stderr, acc_all: acc_all_stderr}

     return stderr.get(metric, None)


 def yesno(x):
     if x:
-        return 'yes'
+        return "yes"
     else:
-        return 'no'
+        return "no"
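Small illustrative sketch of how the stderr helpers above are used (not part of the commit; names assumed from this module): stderr_for_metric returns a callable that estimates the standard error of the aggregated metric, by bootstrap for expensive metrics and by a closed form for mean/acc_all, which is how evaluate() fills the "_stderr" entries.

from lm_eval.metrics import mean, stderr_for_metric

values = [0.0, 1.0, 1.0, 0.0, 1.0]  # per-document scores for one task/metric
stderr_fn = stderr_for_metric(mean, bootstrap_iters=1000)
if stderr_fn is not None:
    print("stderr:", stderr_fn(values))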
from . import gpt2 from . import gpt2
from . import gpt3 from . import gpt3
from . import huggingface
from . import textsynth
from . import dummy from . import dummy
MODEL_REGISTRY = { MODEL_REGISTRY = {
"hf": gpt2.HFLM, "hf": gpt2.HFLM,
"hf-causal": gpt2.HFLM,
"hf-causal-experimental": huggingface.AutoCausalLM,
"hf-seq2seq": huggingface.AutoSeq2SeqLM,
"gpt2": gpt2.GPT2LM, "gpt2": gpt2.GPT2LM,
"gpt3": gpt3.GPT3LM, "gpt3": gpt3.GPT3LM,
"textsynth": textsynth.TextSynthLM,
"dummy": dummy.DummyLM, "dummy": dummy.DummyLM,
} }
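# Usage sketch (hypothetical, not part of the original file): the registry maps
# the `--model` string onto an LM class, e.g.
#   lm_class = MODEL_REGISTRY["hf-causal-experimental"]  # -> huggingface.AutoCausalLM
#   lm = lm_class(pretrained="gpt2")
# the harness normally performs this lookup itself when a model name is given
# on the command line.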
......
import transformers
import torch import torch
import transformers
from typing import Optional
from lm_eval.base import BaseLM from lm_eval.base import BaseLM
class HFLM(BaseLM): class HFLM(BaseLM):
def __init__(
def __init__(self, device='cuda', pretrained='gpt2', revision='main', subfolder=None, tokenizer=None, batch_size=1): self,
device="cuda",
pretrained="gpt2",
revision="main",
low_cpu_mem_usage=None,
subfolder=None,
tokenizer=None,
batch_size=1,
load_in_8bit: Optional[bool] = False,
trust_remote_code: Optional[bool] = False,
):
super().__init__() super().__init__()
assert isinstance(device, str) assert isinstance(device, str)
assert isinstance(pretrained, str) assert isinstance(pretrained, str)
assert isinstance(batch_size, int) assert isinstance(batch_size, (int, str))
if device: device_list = set(
["cuda", "cpu"] + [f"cuda:{i}" for i in range(torch.cuda.device_count())]
)
if device and device in device_list:
self._device = torch.device(device) self._device = torch.device(device)
print(f"Using device '{device}'")
else: else:
self._device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') print("Device not specified")
print(f"Cuda Available? {torch.cuda.is_available()}")
self._device = (
torch.device("cuda")
if torch.cuda.is_available()
else torch.device("cpu")
)
# TODO: update this to be less of a hack once subfolder is fixed in HF # TODO: update this to be less of a hack once subfolder is fixed in HF
revision = revision + ("/" + subfolder if subfolder is not None else "")
self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained( self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained(
pretrained, revision=revision + ("/" + subfolder if subfolder is not None else "") pretrained,
load_in_8bit=load_in_8bit,
low_cpu_mem_usage=low_cpu_mem_usage,
revision=revision,
trust_remote_code=trust_remote_code,
).to(self.device) ).to(self.device)
self.gpt2.eval() self.gpt2.eval()
# pretrained tokenizer for neo is broken for now so just hard-coding this to gpt2
self.tokenizer = transformers.AutoTokenizer.from_pretrained( self.tokenizer = transformers.AutoTokenizer.from_pretrained(
pretrained if tokenizer is None else tokenizer, pretrained if tokenizer is None else tokenizer,
revision=revision + ("/" + subfolder if subfolder is not None else "")) revision=revision + ("/" + subfolder if subfolder is not None else ""))
# assert isinstance(self.tokenizer, ( # assert isinstance(self.tokenizer, (
...@@ -46,6 +73,29 @@ class HFLM(BaseLM): ...@@ -46,6 +73,29 @@ class HFLM(BaseLM):
# gpus = torch.cuda.device_count() # gpus = torch.cuda.device_count()
# if gpus > 1: # if gpus > 1:
# self.gpt2 = nn.DataParallel(self.gpt2) # self.gpt2 = nn.DataParallel(self.gpt2)
revision=revision,
trust_remote_code=trust_remote_code,
)
self.vocab_size = self.tokenizer.vocab_size
if isinstance(
self.tokenizer, (transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast)
):
assert self.tokenizer.encode("hello\n\nhello") == [
31373,
198,
198,
31373,
], self.tokenizer.encode("hello\n\nhello")
# setup for automatic batch size detection
if batch_size == "auto":
self.batch_size_per_gpu = batch_size
else:
self.batch_size_per_gpu = int(batch_size)
@property @property
def eot_token_id(self): def eot_token_id(self):
...@@ -92,12 +142,11 @@ class HFLM(BaseLM): ...@@ -92,12 +142,11 @@ class HFLM(BaseLM):
return self.gpt2(inps)[0] return self.gpt2(inps)[0]
def _model_generate(self, context, max_length, eos_token_id): def _model_generate(self, context, max_length, eos_token_id):
return self.gpt2.generate( generation_kwargs = {"do_sample": False, "max_length": max_length}
context, if eos_token_id is not None:
max_length=max_length, generation_kwargs['eos_token_id'] = eos_token_id
eos_token_id=eos_token_id, generation_kwargs['pad_token_id'] = eos_token_id # setting eos_token_id as pad token
do_sample=False return self.gpt2.generate(context, **generation_kwargs)
)
# for backwards compatibility # for backwards compatibility
......
...@@ -31,22 +31,24 @@ def get_result(response, ctxlen): ...@@ -31,22 +31,24 @@ def get_result(response, ctxlen):
if top_token != token: if top_token != token:
is_greedy = False is_greedy = False
break break
return continuation_logprobs, is_greedy return continuation_logprobs, is_greedy
def oa_completion(**kwargs): def oa_completion(**kwargs):
""" Query OpenAI API for completion. """Query OpenAI API for completion.
Retry with back-off until they respond Retry with back-off until they respond
""" """
import openai import openai
backoff_time = 3 backoff_time = 3
while True: while True:
try: try:
return openai.Completion.create(**kwargs) return openai.Completion.create(**kwargs)
except openai.error.OpenAIError: except openai.error.OpenAIError:
import traceback import traceback
traceback.print_exc() traceback.print_exc()
time.sleep(backoff_time) time.sleep(backoff_time)
backoff_time *= 1.5 backoff_time *= 1.5
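# Usage sketch (hypothetical, not part of the original file): `oa_completion`
# forwards its kwargs to `openai.Completion.create` and retries with
# exponential back-off, so it is called exactly like the API, e.g.
#   response = oa_completion(
#       engine="davinci", prompt=["2 + 2 ="], max_tokens=1, temperature=0.0, logprobs=10
#   )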
...@@ -66,16 +68,19 @@ class GPT3LM(BaseLM): ...@@ -66,16 +68,19 @@ class GPT3LM(BaseLM):
super().__init__() super().__init__()
import openai import openai
self.engine = engine self.engine = engine
self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2') self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained("gpt2")
self.vocab_size = self.tokenizer.vocab_size self.vocab_size = self.tokenizer.vocab_size
# to make the annoying "Using pad_token, but it is not set yet." error go away # to make the annoying "Using pad_token, but it is not set yet." error go away
self.tokenizer.pad_token = "<|endoftext|>" self.tokenizer.pad_token = "<|endoftext|>"
assert self.tokenizer.encode('hello\n\nhello') == [31373, 198, 198, 31373] assert self.tokenizer.encode("hello\n\nhello") == [31373, 198, 198, 31373]
self.truncate = truncate self.truncate = truncate
self.end_of_text_token_id = self.tokenizer.convert_tokens_to_ids(["<|endoftext|>"])[0] self.end_of_text_token_id = self.tokenizer.convert_tokens_to_ids(
["<|endoftext|>"]
)[0]
# Read from environment variable OPENAI_API_SECRET_KEY # Read from environment variable OPENAI_API_SECRET_KEY
openai.api_key = os.environ["OPENAI_API_SECRET_KEY"] openai.api_key = os.environ["OPENAI_API_SECRET_KEY"]
...@@ -105,7 +110,7 @@ class GPT3LM(BaseLM): ...@@ -105,7 +110,7 @@ class GPT3LM(BaseLM):
def tok_encode(self, string: str): def tok_encode(self, string: str):
return self.tokenizer.encode(string, add_special_tokens=False) return self.tokenizer.encode(string, add_special_tokens=False)
def tok_decode(self, tokens): def tok_decode(self, tokens):
return self.tokenizer.decode(tokens) return self.tokenizer.decode(tokens)
...@@ -118,17 +123,22 @@ class GPT3LM(BaseLM): ...@@ -118,17 +123,22 @@ class GPT3LM(BaseLM):
# we care about and so we need some kind of backup for when it isn't # we care about and so we need some kind of backup for when it isn't
toks = x[1] + x[2] toks = x[1] + x[2]
return -len(toks), tuple(toks) return -len(toks), tuple(toks)
reord = utils.Reorderer(requests, _collate)
for chunk in tqdm(list(utils.chunks(reord.get_reordered(), self.REQ_CHUNK_SIZE)), disable=disable_tqdm): re_ord = utils.Reorderer(requests, _collate)
for chunk in tqdm(
list(utils.chunks(re_ord.get_reordered(), self.REQ_CHUNK_SIZE)),
disable=disable_tqdm,
):
inps = [] inps = []
ctxlens = [] ctxlens = []
for cache_key, context_enc, continuation_enc in chunk: for cache_key, context_enc, continuation_enc in chunk:
# max_length+1 because the API takes up to 2049 tokens, including the first context token # max_length+1 because the API takes up to 2049 tokens, including the first context token
inp = (context_enc + continuation_enc)[-(self.max_length+1):] inp = (context_enc + continuation_enc)[-(self.max_length + 1) :]
# TODO: the logic is much simpler if we just look at the length of continuation tokens # TODO: the logic is much simpler if we just look at the length of continuation tokens
ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - (self.max_length+1)) ctxlen = len(context_enc) - max(
0, len(context_enc) + len(continuation_enc) - (self.max_length + 1)
)
inps.append(inp) inps.append(inp)
ctxlens.append(ctxlen) ctxlens.append(ctxlen)
...@@ -137,11 +147,14 @@ class GPT3LM(BaseLM): ...@@ -137,11 +147,14 @@ class GPT3LM(BaseLM):
engine=self.engine, engine=self.engine,
prompt=inps, prompt=inps,
echo=True, echo=True,
max_tokens=0, temperature=0., max_tokens=0,
temperature=0.0,
logprobs=10, logprobs=10,
) )
for resp, ctxlen, (cache_key, context_enc, continuation_enc) in zip(response.choices, ctxlens, chunk): for resp, ctxlen, (cache_key, context_enc, continuation_enc) in zip(
response.choices, ctxlens, chunk
):
answer = get_result(resp, ctxlen) answer = get_result(resp, ctxlen)
res.append(answer) res.append(answer)
...@@ -150,7 +163,7 @@ class GPT3LM(BaseLM): ...@@ -150,7 +163,7 @@ class GPT3LM(BaseLM):
if cache_key is not None: if cache_key is not None:
self.cache_hook.add_partial("loglikelihood", cache_key, answer) self.cache_hook.add_partial("loglikelihood", cache_key, answer)
return reord.get_original(res) return re_ord.get_original(res)
def greedy_until(self, requests): def greedy_until(self, requests):
if not requests: if not requests:
...@@ -160,8 +173,8 @@ class GPT3LM(BaseLM): ...@@ -160,8 +173,8 @@ class GPT3LM(BaseLM):
def _collate(x): def _collate(x):
toks = self.tok_encode(x[0]) toks = self.tok_encode(x[0])
return len(toks), x[0] return len(toks), x[0]
reord = utils.Reorderer(requests, _collate) re_ord = utils.Reorderer(requests, _collate)
def sameuntil_chunks(xs, size): def sameuntil_chunks(xs, size):
ret = [] ret = []
...@@ -172,39 +185,41 @@ class GPT3LM(BaseLM): ...@@ -172,39 +185,41 @@ class GPT3LM(BaseLM):
ret = [] ret = []
lastuntil = x[1] lastuntil = x[1]
ret.append(x) ret.append(x)
if ret: if ret:
yield ret, lastuntil yield ret, lastuntil
# todo: more intelligent batching for heterogeneous `until` # todo: more intelligent batching for heterogeneous `until`
for chunk, until in tqdm(list(sameuntil_chunks(reord.get_reordered(), self.REQ_CHUNK_SIZE))): for chunk, until in tqdm(
list(sameuntil_chunks(re_ord.get_reordered(), self.REQ_CHUNK_SIZE))
):
inps = [] inps = []
for context, _ in chunk: for context, _ in chunk:
context_enc = self.tok_encode(context) context_enc = self.tok_encode(context)
inp = context_enc[-(self.max_length - self.max_gen_toks):] inp = context_enc[-(self.max_length - self.max_gen_toks) :]
inps.append(inp) inps.append(inp)
response = oa_completion( response = oa_completion(
engine=self.engine, engine=self.engine,
prompt=inps, prompt=inps,
max_tokens=self.max_gen_toks, max_tokens=self.max_gen_toks,
temperature=0., temperature=0.0,
logprobs=10, logprobs=10,
stop=until, stop=until,
) )
for resp, (context, until_) in zip(response.choices, chunk): for resp, (context, until_) in zip(response.choices, chunk):
s = resp['text'] s = resp["text"]
for term in until_: for term in until_:
s = s.split(term)[0] s = s.split(term)[0]
# partial caching # partial caching
self.cache_hook.add_partial("greedy_until", (context, until_), s) self.cache_hook.add_partial("greedy_until", (context, until_), s)
res.append(s) res.append(s)
return reord.get_original(res) return re_ord.get_original(res)
def _model_call(self, inps): def _model_call(self, inps):
# Isn't used because we override _loglikelihood_tokens # Isn't used because we override _loglikelihood_tokens
......
import math
import torch
import torch.nn.functional as F
import transformers
import peft
from typing import List, Mapping, NewType, Optional, Tuple, Union
from tqdm import tqdm
from transformers import BatchEncoding
from accelerate import find_executable_batch_size
from lm_eval import utils
from lm_eval.base import BaseLM
TokenSequence = Union[List[int], torch.LongTensor, torch.Tensor, BatchEncoding]
_DeviceMapping = NewType("DeviceMapping", Mapping[str, Union[int, str, torch.device]])
def _get_accelerate_args(
device_map_option: Optional[str] = "auto",
max_memory_per_gpu: Optional[Union[int, str]] = None,
max_cpu_memory: Optional[Union[int, str]] = None,
offload_folder: Optional[str] = "./offload",
) -> dict:
"""Returns the kwargs needed to apply `accelerate` in `AutoModel.from_pretrained`."""
max_memory = {}
if max_memory_per_gpu is not None:
max_memory_per_gpu_map = {
device_idx: max_memory_per_gpu
for device_idx in range(torch.cuda.device_count())
}
max_memory.update(max_memory_per_gpu_map)
if max_cpu_memory is not None:
max_memory["cpu"] = max_cpu_memory
args = {}
if max_memory:
args["max_memory"] = max_memory
args["device_map"] = device_map_option
args["offload_folder"] = offload_folder
return args
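# Example of the returned kwargs (hypothetical values, not part of the original
# file): on a machine with two GPUs,
#   _get_accelerate_args(max_memory_per_gpu="20GiB", max_cpu_memory="60GiB")
# evaluates to
#   {"max_memory": {0: "20GiB", 1: "20GiB", "cpu": "60GiB"},
#    "device_map": "auto", "offload_folder": "./offload"}
# and is passed straight through to `from_pretrained` below.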
def _get_dtype(
dtype: Union[str, torch.dtype], config: Optional[transformers.AutoConfig] = None
) -> torch.dtype:
"""Converts `dtype` from `str` to torch.dtype when possible."""
if dtype is None and config is not None:
_torch_dtype = config.torch_dtype
elif isinstance(dtype, str) and dtype != "auto":
# Convert `str` args torch dtype: `float16` -> `torch.float16`
_torch_dtype = getattr(torch, dtype)
else:
_torch_dtype = dtype
return _torch_dtype
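# Example (not part of the original file): string dtypes are resolved via
# getattr on torch, anything else is passed through unchanged.
#   _get_dtype("float16")       # -> torch.float16
#   _get_dtype(torch.bfloat16)  # -> torch.bfloat16
#   _get_dtype("auto")          # -> "auto", left for `from_pretrained` to resolve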
class HuggingFaceAutoLM(BaseLM):
AUTO_CONFIG_CLASS: transformers.AutoConfig = transformers.AutoConfig
AUTO_TOKENIZER_CLASS: transformers.AutoTokenizer = transformers.AutoTokenizer
AUTO_MODEL_CLASS: transformers.AutoModel = None
AUTO_PEFT_CLASS: peft.PeftModel = None
# Default max sequence length setting for when no `max_length` is provided
# or no max length config setting is found in the model or tokenizer.
_DEFAULT_MAX_LENGTH: int = 2048
def __init__(
self,
pretrained: str,
tokenizer: Optional[str] = None,
subfolder: Optional[str] = None,
revision: Optional[str] = "main",
batch_size: Optional[Union[int, str]] = 1,
max_gen_toks: Optional[int] = 256,
max_length: Optional[int] = None,
add_special_tokens: Optional[bool] = None,
use_accelerate: Optional[bool] = False,
device_map_option: Optional[str] = "auto",
max_memory_per_gpu: Optional[Union[int, str]] = None,
max_cpu_memory: Optional[Union[int, str]] = None,
offload_folder: Optional[str] = "./offload",
dtype: Optional[Union[str, torch.dtype]] = None,
device: Optional[Union[int, str]] = "cuda",
peft: Optional[str] = None,
load_in_8bit: Optional[bool] = False,
trust_remote_code: Optional[bool] = False,
):
"""Initializes a HuggingFace `AutoModel` and `AutoTokenizer` for evaluation.
Args:
pretrained (str):
The HuggingFace Hub model ID name or the path to a pre-trained
model to load. This is effectively the `pretrained_model_name_or_path`
argument of `from_pretrained` in the HuggingFace `transformers` API.
add_special_tokens (bool, optional, defaults to None):
Whether to add special tokens to the input sequences. If `None`, the
default value will be set to `True` for seq2seq models (e.g. T5) and
`False` for causal models.
WARNING: Evaluating causal models with `add_special_tokens=True` is
currently __not__ supported.
> Large model loading `accelerate` arguments
use_accelerate (bool, optional, defaults to False):
If True, uses the `accelerate` library to load a large model across
multiple devices.
device_map_option (str, optional, defaults to "auto"):
The device map option to use when loading the model with
`accelerate`.
Options:
"auto", "balanced", "balanced_low_0", "sequential"
See the `accelerate` docs for more details on these options:
https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained.device_map
max_memory_per_gpu (Union[int, str], optional, defaults to None):
The maximum memory available for each GPU in bytes as `int` or in
the format f"{significand}{unit_symbol}" where {unit_symbol} is
any of ["GB", "MB", "GIB", "MIB"]. Refer to the `max_memory` arg in
the "Parameters for big model inference" section of the following
docs:
https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained.max_memory
max_cpu_memory (Union[int, str], optional, defaults to None):
The maximum available CPU RAM in bytes as `int` or in the format
f"{significand}{unit_symbol}" where {unit_symbol} is any of
["GB", "MB", "GIB", "MIB"]. Refer to the `max_memory` arg in the
"Parameters for big model inference" section of the following docs:
https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained.max_memory
offload_folder (str, optional, defaults to "./offload"):
The folder to offload weights into if `device_map` contains any
"disk" value.
dtype (Union[str, torch.dtype], optional, defaults to None):
Converts the model weights to `dtype`, if specified. Strings get
converted to `torch.dtype` objects (e.g. `float16` -> `torch.float16`).
Use `dtype="auto"` to derive the type from the model’s weights.
peft (str, optional, defaults to None):
Path of the adapter weights to load from Huggingface. This will usually
include a directory that includes the files `adapter_config.json` and
`adapter_model.bin`. Compatible with [PEFT](https://github.com/huggingface/peft)
load_in_8bit (bool, optional, defaults to False):
If True, will convert the loaded model into mixed-8bit quantized model. See:
https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained.load_in_8bit
trust_remote_code (bool, optional, defaults to False):
If True, will trust the remote code when loading the model.
"""
super().__init__()
assert isinstance(pretrained, str)
assert isinstance(device, str)
assert isinstance(batch_size, (int, str))
if (
add_special_tokens is not None
and self.AUTO_MODEL_CLASS is transformers.AutoModelForCausalLM
):
# TODO: Support evaluating causal models with special tokens. Currently,
# this is not possible because the `_loglikelihood_tokens()` method for
# causal LMs makes a no-special-tokens assumption given that contexts
# and labels/continuations are tokenized separately without special
# tokens, concatenated, and then processed as inputs.
assert (
not add_special_tokens
), "Evaluating causal models with `add_special_tokens=True` is currently not supported."
# setup for automatic batch size detection
if batch_size == "auto":
self._batch_size = batch_size
else:
self._batch_size = int(batch_size)
self._max_gen_toks = max_gen_toks
self._max_length = max_length
self._config = self.AUTO_CONFIG_CLASS.from_pretrained(
pretrained,
trust_remote_code=trust_remote_code,
revision=revision + ("/" + subfolder if subfolder is not None else ""),
)
self._add_special_tokens = add_special_tokens
self.tokenizer = self._create_auto_tokenizer(
pretrained=pretrained,
revision=revision,
subfolder=subfolder,
tokenizer=tokenizer,
)
self.tokenizer.model_max_length = self.max_length
model_kwargs = {}
if use_accelerate:
model_kwargs = _get_accelerate_args(
device_map_option,
max_memory_per_gpu,
max_cpu_memory,
offload_folder,
)
model_kwargs["load_in_8bit"] = load_in_8bit
self.model = self._create_auto_model(
pretrained=pretrained,
trust_remote_code=trust_remote_code,
revision=revision,
subfolder=subfolder,
torch_dtype=_get_dtype(dtype, self._config),
**model_kwargs,
)
# note: peft_path can be different than pretrained model path
if peft is not None:
self.model = self._create_auto_model_peft(
model=self.model,
peft=peft,
revision=revision,
subfolder=subfolder,
torch_dtype=_get_dtype(dtype, self._config),
**model_kwargs,
)
self.model.eval()
torch.set_grad_enabled(False)
self._device = device
if use_accelerate and "lm_head" in self.model.hf_device_map:
# `accelerate` can place `lm_head` weights on a different device than
# the user specified one so we force `self._device` to be the same as
# `lm_head`'s.
self._device = self.model.hf_device_map["lm_head"]
if not use_accelerate:
self.model.to(self._device)
def _create_auto_model(
self,
*,
pretrained: str,
revision: str,
subfolder: str,
device_map: Optional[Union[str, _DeviceMapping]] = None,
max_memory: Optional[dict] = None,
offload_folder: Optional[str] = None,
load_in_8bit: Optional[bool] = False,
trust_remote_code: Optional[bool] = False,
torch_dtype: Optional[Union[str, torch.dtype]] = None,
) -> transformers.AutoModel:
"""Returns a pre-trained pytorch model from a pre-trained model configuration."""
model = self.AUTO_MODEL_CLASS.from_pretrained(
pretrained,
revision=revision + ("/" + subfolder if subfolder is not None else ""),
device_map=device_map,
max_memory=max_memory,
offload_folder=offload_folder,
load_in_8bit=load_in_8bit,
trust_remote_code=trust_remote_code,
torch_dtype=torch_dtype,
)
return model
def _create_auto_model_peft(
self,
*,
model: transformers.PreTrainedModel,
peft: str,
revision: str,
subfolder: str,
device_map: Optional[Union[str, _DeviceMapping]] = None,
max_memory: Optional[dict] = None,
offload_folder: Optional[str] = None,
load_in_8bit: Optional[bool] = False,
trust_remote_code: Optional[bool] = False,
torch_dtype: Optional[Union[str, torch.dtype]] = None,
):
model = self.AUTO_PEFT_CLASS.from_pretrained(
model,
peft,
revision=revision + ("/" + subfolder if subfolder is not None else ""),
device_map=device_map,
max_memory=max_memory,
offload_folder=offload_folder,
load_in_8bit=load_in_8bit,
trust_remote_code=trust_remote_code,
torch_dtype=torch_dtype,
)
return model
def _create_auto_tokenizer(
self,
*,
pretrained: str,
revision: str,
subfolder: str,
tokenizer: Optional[str] = None,
) -> transformers.PreTrainedTokenizer:
"""Returns a pre-trained tokenizer from a pre-trained tokenizer configuration."""
tokenizer = self.AUTO_TOKENIZER_CLASS.from_pretrained(
pretrained if tokenizer is None else tokenizer,
revision=revision + ("/" + subfolder if subfolder is not None else ""),
)
tokenizer.pad_token = tokenizer.eos_token
return tokenizer
@property
def add_special_tokens(self) -> bool:
"""Whether to include special tokens in encoded text. This should be
determined by whether or not the model was trained with special tokens.
TODO: Remove these conditionals once HuggingFace supports a way to
check whether or not an arbitrary model was trained with special tokens.
"""
if self._add_special_tokens is not None:
return self._add_special_tokens
elif self.AUTO_MODEL_CLASS is transformers.AutoModelForCausalLM:
return False
elif self.AUTO_MODEL_CLASS is transformers.AutoModelForSeq2SeqLM:
return True
else:
raise ValueError(
"Could not determine `add_special_tokens` value from the model "
"class. Set to `True` or `False` depending on whether the model "
"was pre-trained with special tokens."
)
@property
def eot_token(self) -> str:
return self.tokenizer.eos_token
@property
def eot_token_id(self) -> int:
return self.tokenizer.eos_token_id
@property
def max_gen_toks(self) -> int:
return self._max_gen_toks
@property
def max_length(self) -> int:
"""Return the maximum sequence length of the model.
NOTE: Different model configurations have different max sequence length
attribute names.
- n_positions: (CTRLConfig)
- max_position_embeddings: (BartConfig, RoFormerConfig)
- n_ctx: (GPT2Config)
NOTE: For relative position encoded models you should specify the max
sequence length of the model in the constructor via `max_length`.
"""
if self._max_length is not None:
return self._max_length
# Try to get the sequence length from the model config.
seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx")
for attr in seqlen_config_attrs:
if hasattr(self._config, attr):
return getattr(self._config, attr)
if hasattr(self.tokenizer, "model_max_length"):
if self.tokenizer.model_max_length == 1000000000000000019884624838656:
return self._DEFAULT_MAX_LENGTH
return self.tokenizer.model_max_length
return self._DEFAULT_MAX_LENGTH
@property
def batch_size(self) -> int:
# TODO: Add adaptive batch size.
return self._batch_size # * gpus
@property
def device(self) -> Union[int, str, torch.device]:
return self._device
def tok_encode(self, string: str) -> TokenSequence:
# TODO: Merge `tok_encode_batch` here.
return self.tokenizer.encode(string, add_special_tokens=self.add_special_tokens)
def tok_encode_batch(self, strings: List[str]) -> TokenSequence:
return self.tokenizer(
strings,
padding=True,
add_special_tokens=self.add_special_tokens,
return_tensors="pt",
)
def tok_decode(self, tokens: torch.LongTensor) -> List[str]:
return self.tokenizer.batch_decode(tokens, skip_special_tokens=True)
def greedy_until(
self, requests: List[Tuple[str, Union[List[str], str]]]
) -> List[str]:
def _collate(x):
tokens = self.tok_encode(x[0])
return len(tokens), x[0]
results = []
reorder = utils.Reorderer(requests, _collate)
adaptive_batch_size = None
if self.batch_size == "auto":
# using rolling window with maximum context
print("Passed argument batch_size = auto. Detecting largest batch size")
@find_executable_batch_size(
starting_batch_size=512
) # if OOM, then halves batch_size and tries again
def forward_batch(batch_size):
test_batch = torch.ones(
(batch_size, self.max_length), device=self.device
).long()
for _ in range(5):
_ = F.log_softmax(self._model_call(test_batch), dim=-1).cpu()
return batch_size
batch_size = forward_batch()
print(f"Determined Largest batch size: {batch_size}")
adaptive_batch_size = batch_size
for chunk in utils.chunks(
tqdm(reorder.get_reordered(), disable=False),
self.batch_size if self.batch_size != "auto" else adaptive_batch_size,
):
context = [c[0] for c in chunk]
request_args = chunk[0][1]
stop = request_args.get("until", None)
stop_sequences = stop if isinstance(stop, list) else [stop]
max_generation_length = request_args.get("max_length", None)
assert (
isinstance(max_generation_length, int) or max_generation_length is None
)
assert isinstance(stop_sequences, list) or stop_sequences is None
# TODO: Find a better way to handle stop sequences for 0-shot.
if stop_sequences is None:
until = [self.eot_token]
else:
until = stop_sequences + [self.eot_token]
if max_generation_length is None:
max_tokens = self.max_gen_toks
else:
max_tokens = max_generation_length
token_context = self.tok_encode_batch(context)
responses = self._model_generate(
inputs=token_context,
max_tokens=max_tokens,
stop=until,
)
responses = self.tok_decode(responses.tolist())
for response in responses:
# Ensure the generated responses do not contain the stop sequences.
for term in until:
response = response.split(term)[0]
# partial caching
self.cache_hook.add_partial("greedy_until", (context, until), response)
results.append(response)
return reorder.get_original(results)
class AutoCausalLM(HuggingFaceAutoLM):
"""Causal language modeling.
You can find a set of supported models in the HF documentation:
https://huggingface.co/docs/transformers/main/model_doc/auto#transformers.AutoModelForCausalLM
"""
AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
AUTO_PEFT_CLASS = peft.PeftModel
def _create_auto_tokenizer(
self,
*,
pretrained: str,
revision: str,
subfolder: str,
tokenizer: Optional[str] = None,
) -> transformers.PreTrainedTokenizer:
tokenizer = super()._create_auto_tokenizer(
pretrained=pretrained,
revision=revision,
subfolder=subfolder,
tokenizer=tokenizer,
)
tokenizer.padding_side = "left"
return tokenizer
def _model_call(
self, inputs: TokenSequence, labels: Optional[TokenSequence] = None
) -> TokenSequence:
return self.model(inputs)["logits"]
def _model_generate(
self,
inputs: transformers.BatchEncoding,
max_tokens: int,
stop: Optional[List[str]] = None,
) -> TokenSequence:
# Ensure that the context does not encroach into the `space`
# for the generation.
input_ids = inputs["input_ids"][:, self.max_gen_toks - self.max_length :]
attention_mask = inputs["attention_mask"][
:, self.max_gen_toks - self.max_length :
]
input_ids = input_ids.to(self.device)
attention_mask = attention_mask.to(self.device)
stopping_criteria = stop_sequences_criteria(
self.tokenizer, stop, input_ids.shape[1], input_ids.shape[0]
)
generations = self.model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
# GPT style models require the `generate` `max_length` arg to include the
# context length, so we instead set `max_new_tokens` which is the number
# of new tokens to generate, excluding the current number of tokens.
max_new_tokens=max_tokens,
stopping_criteria=stopping_criteria,
do_sample=False,
)
return utils.select_continuation_from_batch_left_padding(
generations, max_context_size=inputs["input_ids"].size(1)
)
class AutoSeq2SeqLM(HuggingFaceAutoLM):
"""Seq2Seq language modeling.
You can find a set of supported models in the following documentation:
https://huggingface.co/docs/transformers/main/model_doc/auto#transformers.AutoModelForSeq2SeqLM
"""
AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM
AUTO_PEFT_CLASS = peft.PeftModel
@property
def max_length(self) -> int:
"""Return the maximum sequence length of the model.
TODO: Currently only works for relative position encoded Seq2Seq models.
"""
if self._max_length is not None:
return self._max_length
return self._DEFAULT_MAX_LENGTH
def loglikelihood(
self, requests: List[Tuple[str, str]]
) -> List[Tuple[float, bool]]:
new_requests = []
for chunk in utils.chunks(requests, self.batch_size):
context, continuation = zip(*chunk)
# Fill empty contexts with the EOT token.
context = [
f"{self.eot_token}" if len(text) == 0 else text for text in context
]
context_enc = self.tok_encode_batch(context)
for key in context_enc:
context_enc[key] = context_enc[key][:, -self.max_length :]
# Remove leading whitespace introduced by the default
# `text_target_separator` since the context and continuation
# will not be concatenated as a single (decoder) input.
continuation = [text.lstrip() for text in continuation]
continuation_enc = self.tok_encode_batch(list(continuation))
for key in continuation_enc:
continuation_enc[key] = continuation_enc[key][:, -self.max_length :]
new_requests.append(
((context, continuation), context_enc, continuation_enc)
)
return self._loglikelihood_tokens(new_requests)
def loglikelihood_rolling(self, requests: List[Tuple[str, str]]) -> List[float]:
loglikelihoods = []
for (string,) in tqdm(requests):
rolling_token_windows = list(
map(
utils.make_disjoint_window,
utils.get_rolling_token_windows(
token_list=self.tok_encode(string),
prefix_token=self.eot_token_id,
max_seq_len=self.max_length,
context_len=1,
),
)
)
contexts, conts = utils.split_and_pad_windows(
rolling_token_windows,
pad_token_id=self.eot_token_id,
max_seq_len=self.max_length,
)
# Manually create BatchEncoding tensors with attention masks as
# expected by `self._model_call` in `self._loglikelihood_tokens`.
contexts_enc = torch.Tensor(contexts).long()
contexts_enc = transformers.tokenization_utils_base.BatchEncoding(
{
"input_ids": contexts_enc,
"attention_mask": (contexts_enc != self.eot_token_id).long(),
}
)
conts_enc = torch.Tensor(conts).long()
conts_enc = transformers.tokenization_utils_base.BatchEncoding(
{
"input_ids": conts_enc,
"attention_mask": (conts_enc != self.eot_token_id).long(),
}
)
# TODO: Extract out this call so it only gets called once and also
# somehow figure out partial caching for.
rolling_token_windows_request = [
((contexts, conts), contexts_enc, conts_enc)
]
string_nll = self._loglikelihood_tokens(
rolling_token_windows_request, disable_tqdm=True
)
string_nll = [x[0] for x in string_nll] # discard is_greedy
string_nll = sum(string_nll)
loglikelihoods.append(string_nll)
return loglikelihoods
def _loglikelihood_tokens(
self,
requests: List[Tuple[Tuple[str, str], TokenSequence, TokenSequence]],
disable_tqdm: Optional[bool] = False,
) -> List[Tuple[float, bool]]:
results = []
for chunk in tqdm(
requests, total=len(requests), disable=disable_tqdm
):
cache_keys, inputs_tokens, targets_tokens = chunk
inputs_tokens = inputs_tokens.to(self.device)
targets_tokens = targets_tokens.to(self.device)
outputs = self._model_call(inputs=inputs_tokens, labels=targets_tokens)
log_softmaxes = F.log_softmax(outputs.logits, dim=-1)
output_iterator = zip(
zip(cache_keys[0], cache_keys[1]),
log_softmaxes,
targets_tokens["input_ids"],
targets_tokens["attention_mask"],
)
for cache_key, log_softmax, target_tokens, target_mask in output_iterator:
length = target_mask.sum()
log_softmax = log_softmax[:length]
target_tokens = target_tokens[:length]
greedy_tokens = log_softmax.argmax(dim=-1)
max_equal = (greedy_tokens == target_tokens).all()
target_logits = torch.gather(
log_softmax, 1, target_tokens.unsqueeze(-1)
).squeeze(-1)
answer = (float(target_logits.sum()), bool(max_equal))
results.append(answer)
if cache_key is not None:
self.cache_hook.add_partial("loglikelihood", cache_key, answer)
return results
def _model_call(
self, inputs: TokenSequence, labels: Optional[TokenSequence] = None
) -> TokenSequence:
return self.model(**inputs, labels=labels["input_ids"])
def _model_generate(
self,
inputs: transformers.BatchEncoding,
max_tokens: int,
stop: Optional[List[str]] = None,
) -> TokenSequence:
input_ids = inputs["input_ids"][:, -self.max_length :].to(self.device)
attention_mask = inputs["attention_mask"][:, -self.max_length :].to(self.device)
# Generate one token to calculate the number of start tokens prepended to decoder_input_ids
# (leaving this here in case the below assumption is violated in the future)
# one_tok_gen = self.model.generate(
# input_ids=torch.zeros((1, 1), dtype=torch.int),
# min_length=2,
# max_new_tokens=1,
# ).squeeze()
# initial_decoder_input_length = len(one_tok_gen) - 1
# Assume that there will always only be one token in the decoder inputs, assumption holds for existing HF models
stopping_criteria = stop_sequences_criteria(
self.tokenizer, stop, 1, input_ids.shape[0]
)
generations = self.model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
max_new_tokens=max_tokens,
stopping_criteria=stopping_criteria,
do_sample=False,
)
return generations
class MultiTokenEOSCriteria(transformers.StoppingCriteria):
"""Criteria to stop on the specified multi-token sequence."""
def __init__(
self,
sequence: str,
tokenizer: transformers.PreTrainedTokenizer,
initial_decoder_input_length: int,
batch_size: int,
):
self.initial_decoder_input_length = initial_decoder_input_length
self.done_tracker = [False] * batch_size
self.sequence = sequence
self.sequence_ids = tokenizer.encode(sequence, add_special_tokens=False)
self.sequence_id_len = len(self.sequence_ids)
self.tokenizer = tokenizer
def __call__(self, input_ids, scores, **kwargs) -> bool:
# For efficiency, we compare the last n tokens where n is the number of tokens in the stop_sequence
lookback_ids_batch = input_ids[:, self.initial_decoder_input_length :][
:, -self.sequence_id_len :
]
lookback_tokens_batch = self.tokenizer.batch_decode(lookback_ids_batch)
for i, done in enumerate(self.done_tracker):
if not done:
self.done_tracker[i] = self.sequence in lookback_tokens_batch[i]
return False not in self.done_tracker
def stop_sequences_criteria(
tokenizer: transformers.PreTrainedTokenizer,
stop_sequences: List[str],
initial_decoder_input_length: int,
batch_size: int,
) -> transformers.StoppingCriteriaList:
return transformers.StoppingCriteriaList(
[
*[
MultiTokenEOSCriteria(
sequence, tokenizer, initial_decoder_input_length, batch_size
)
for sequence in stop_sequences
],
]
)
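# Usage sketch (hypothetical, not part of the original file): build the
# criteria once per generation batch and hand them to `generate`, mirroring
# `_model_generate` above, e.g.
#   criteria = stop_sequences_criteria(
#       tokenizer, ["\n\n", "Question:"], input_ids.shape[1], input_ids.shape[0]
#   )
#   model.generate(
#       input_ids=input_ids, stopping_criteria=criteria,
#       max_new_tokens=32, do_sample=False,
#   )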
""" TextSynth API
Implementation provided by Fabrice Bellard:
https://github.com/EleutherAI/lm-evaluation-harness/issues/295
In order to use the API, you must have a valid TextSynth account and
enough credits.
Example usage:
python main.py --model textsynth --model_args engine=gptj_6B --no_cache --tasks piqa
Homepage: https://textsynth.com/index.html
"""
import logging
import os
import requests as _requests
import time
from tqdm import tqdm
from lm_eval.base import BaseLM
logger = logging.getLogger(__name__)
def textsynth_completion(**kwargs):
"""Query TextSynth API for completion.
Retry with back-off until they respond.
"""
backoff_time = 3
while True:
try:
return _requests.post(**kwargs)
except _requests.exceptions.RequestException:
import traceback
traceback.print_exc()
time.sleep(backoff_time)
backoff_time *= 1.5
class TextSynthLM(BaseLM):
def __init__(self, engine, truncate=False):
"""
:param engine: str
TextSynth API engine (e.g. `gptj_6B`)
:param truncate: bool
Truncate input if too long (if False and input is too long, throw error)
"""
super().__init__()
self.engine = engine
self.truncate = truncate
self.api_url = "https://api.textsynth.com"
# Read from environment variable TEXTSYNTH_API_SECRET_KEY
self.api_key = os.environ["TEXTSYNTH_API_SECRET_KEY"]
@property
def eot_token_id(self):
# Isn't used because we override loglikelihood, loglikelihood_rolling and greedy_until
raise NotImplementedError()
@property
def max_length(self):
# NOTE: Turn on truncation to avoid errors on long inputs.
return 2048
@property
def max_gen_toks(self):
return 256
@property
def batch_size(self):
# Isn't used because we override loglikelihood, loglikelihood_rolling and greedy_until
raise NotImplementedError()
@property
def device(self):
# Isn't used because we override loglikelihood, loglikelihood_rolling and greedy_until
raise NotImplementedError()
def tok_encode(self, string: str):
# Isn't used because we override loglikelihood, loglikelihood_rolling and greedy_until
raise NotImplementedError()
def tok_decode(self, tokens):
# Isn't used because we override loglikelihood, loglikelihood_rolling and greedy_until
raise NotImplementedError()
def loglikelihood(self, requests):
res = []
for context, continuation in tqdm(requests):
response = textsynth_completion(
url=self.api_url + "/v1/engines/" + self.engine + "/logprob",
headers={"Authorization": "Bearer " + self.api_key},
json={"context": context, "continuation": continuation},
)
resp = response.json()
if "logprob" in resp:
logprob = resp["logprob"]
is_greedy = resp["is_greedy"]
res.append((logprob, is_greedy))
else:
logger.error(
f"The following response does not contain `logprobs`. Got:\n{resp}"
)
assert False
return res
def loglikelihood_rolling(self, requests):
# TODO: The TextSynth API does not support tokenized inputs so we cannot
# manually partition long contexts into smaller rolling windows as
# done for other models derived from `BaseLM`. Override this method
# with a windowing scheme that works for direct string inputs.
raise NotImplementedError(
"`loglikelihood_rolling` is currently not supported due to lack of "
"input tokenization support from TextSynth."
)
def greedy_until(self, requests):
if not requests:
return []
res = []
for request in tqdm(requests):
inp = request[0]
request_args = request[1]
until = request_args["until"]
response = textsynth_completion(
url=self.api_url + "/v1/engines/" + self.engine + "/completions",
headers={"Authorization": "Bearer " + self.api_key},
json={
"prompt": inp,
"max_tokens": self.max_gen_toks,
"top_k": 1,
"stop": until,
},
)
resp = response.json()
if "text" in resp:
s = resp["text"]
res.append(s)
else:
logger.error(
f"The following response does not contain generated `text`. "
"Got:\n{resp}"
)
assert False
return res
def _model_call(self, inps):
# Isn't used because we override _loglikelihood_tokens
raise NotImplementedError()
def _model_generate(self, context, max_length, eos_token_id):
# Isn't used because we override greedy_until
raise NotImplementedError()
...@@ -15,20 +15,19 @@ from . import wsc273 ...@@ -15,20 +15,19 @@ from . import wsc273
from . import winogrande from . import winogrande
from . import quac from . import quac
from . import hellaswag from . import hellaswag
from . import swag
from . import openbookqa from . import openbookqa
from . import squad from . import squad
from . import naturalqs from . import naturalqs
from . import sat from . import sat
from . import arithmetic from . import arithmetic
from . import lambada from . import lambada
from . import race
from . import piqa from . import piqa
from . import prost from . import prost
from . import mc_taco from . import mc_taco
from . import triviaqa from . import triviaqa
from . import pubmedqa from . import pubmedqa
from . import sciq from . import sciq
from . import webqs
from . import qasper from . import qasper
from . import qa4mre from . import qa4mre
from . import translation from . import translation
...@@ -58,6 +57,15 @@ from . import ko_translation ...@@ -58,6 +57,15 @@ from . import ko_translation
from . import korquad from . import korquad
from . import korunsmile from . import korunsmile
from . import kohatespeech from . import kohatespeech
from . import toxigen
from . import crowspairs
from . import xcopa
from . import bigbench
from . import xstorycloze
from . import xwinograd
from . import pawsx
from . import xnli
from . import mgsm
######################################## ########################################
# Translation tasks # Translation tasks
...@@ -66,15 +74,15 @@ from . import kohatespeech ...@@ -66,15 +74,15 @@ from . import kohatespeech
# 6 total # 6 total
gpt3_translation_benchmarks = { gpt3_translation_benchmarks = {
"wmt14": ['en-fr', 'fr-en'], # French "wmt14": ["en-fr", "fr-en"], # French
"wmt16": ['en-ro', 'ro-en', 'de-en', 'en-de'], # German, Romanian "wmt16": ["en-ro", "ro-en", "de-en", "en-de"], # German, Romanian
} }
# 28 total # 28 total
selected_translation_benchmarks = { selected_translation_benchmarks = {
**gpt3_translation_benchmarks, **gpt3_translation_benchmarks,
"wmt20": sacrebleu.get_langpairs_for_testset("wmt20"), "wmt20": sacrebleu.get_langpairs_for_testset("wmt20"),
"iwslt17": ['en-ar', 'ar-en'] # Arabic "iwslt17": ["en-ar", "ar-en"], # Arabic
} }
# 319 total # 319 total
...@@ -98,7 +106,7 @@ TASK_REGISTRY = { ...@@ -98,7 +106,7 @@ TASK_REGISTRY = {
"rte": glue.RTE, "rte": glue.RTE,
"qnli": glue.QNLI, "qnli": glue.QNLI,
"qqp": glue.QQP, "qqp": glue.QQP,
#"stsb": glue.STSB, # not implemented yet # "stsb": glue.STSB, # not implemented yet
"sst": glue.SST, "sst": glue.SST,
"wnli": glue.WNLI, "wnli": glue.WNLI,
# SuperGLUE # SuperGLUE
...@@ -109,6 +117,7 @@ TASK_REGISTRY = { ...@@ -109,6 +117,7 @@ TASK_REGISTRY = {
"record": superglue.ReCoRD, "record": superglue.ReCoRD,
"wic": superglue.WordsInContext, "wic": superglue.WordsInContext,
"wsc": superglue.SGWinogradSchemaChallenge, "wsc": superglue.SGWinogradSchemaChallenge,
# Order by benchmark/genre? # Order by benchmark/genre?
"coqa": coqa.CoQA, "coqa": coqa.CoQA,
...@@ -116,38 +125,42 @@ TASK_REGISTRY = { ...@@ -116,38 +125,42 @@ TASK_REGISTRY = {
"lambada": lambada.LAMBADA, "lambada": lambada.LAMBADA,
"lambada_cloze": lambada_cloze.LAMBADA_cloze, "lambada_cloze": lambada_cloze.LAMBADA_cloze,
# Order by benchmark/genre?
"coqa": coqa.CoQA,
"drop": drop.DROP,
"lambada_openai": lambada.LambadaOpenAI,
"lambada_standard": lambada.LambadaStandard,
"lambada_openai_cloze": lambada_cloze.LambadaOpenAICloze,
"lambada_standard_cloze": lambada_cloze.LambadaStandardCloze,
# multilingual lambada # multilingual lambada
**lambada_multilingual.construct_tasks(), **lambada_multilingual.construct_tasks(),
"wikitext": wikitext.WikiText, "wikitext": wikitext.WikiText,
# "cbt-cn": cbt.CBTCN, # disabled pending context length fix # "cbt-cn": cbt.CBTCN, # disabled pending context length fix
# "cbt-ne": cbt.CBTNE, # disabled pending context length fix # "cbt-ne": cbt.CBTNE, # disabled pending context length fix
"piqa": piqa.PiQA, "piqa": piqa.PiQA,
"prost": prost.PROST, "prost": prost.PROST,
"mc_taco": mc_taco.MCTACO, "mc_taco": mc_taco.MCTACO,
# Science related # Science related
"pubmedqa" : pubmedqa.Pubmed_QA, "pubmedqa": pubmedqa.Pubmed_QA,
"sciq" : sciq.SciQ, "sciq": sciq.SciQ,
"qasper": qasper.QASPER, "qasper": qasper.QASPER,
"qa4mre_2011": qa4mre.QA4MRE_2011,
"qa4mre_2011" : qa4mre.QA4MRE_2011, "qa4mre_2012": qa4mre.QA4MRE_2012,
"qa4mre_2012" : qa4mre.QA4MRE_2012, "qa4mre_2013": qa4mre.QA4MRE_2013,
"qa4mre_2013" : qa4mre.QA4MRE_2013,
"triviaqa": triviaqa.TriviaQA, "triviaqa": triviaqa.TriviaQA,
"arc_easy": arc.ARCEasy, "arc_easy": arc.ARCEasy,
"arc_challenge": arc.ARCChallenge, "arc_challenge": arc.ARCChallenge,
# "quac": quac.QuAC, # not implemented yet # "quac": quac.QuAC, # not implemented yet
"logiqa": logiqa.LogiQA, "logiqa": logiqa.LogiQA,
"hellaswag": hellaswag.HellaSwag, "hellaswag": hellaswag.HellaSwag,
"swag": swag.SWAG,
"openbookqa": openbookqa.OpenBookQA, "openbookqa": openbookqa.OpenBookQA,
"squad2": squad.SQuAD2, "squad2": squad.SQuAD2,
"race": race.RACE, "race": race.RACE,
# "naturalqs": naturalqs.NaturalQs, # not implemented yet # "naturalqs": naturalqs.NaturalQs, # not implemented yet
"headqa": headqa.HeadQAEsDeprecated, # for backwards compat - headqa used to default to es "headqa": headqa.HeadQAEsDeprecated, # for backwards compat - headqa used to default to es
"headqa_es": headqa.HeadQAEs, "headqa_es": headqa.HeadQAEs,
"headqa_en": headqa.HeadQAEn, "headqa_en": headqa.HeadQAEn,
"mathqa": mathqa.MathQA, "mathqa": mathqa.MathQA,
...@@ -157,21 +170,17 @@ TASK_REGISTRY = { ...@@ -157,21 +170,17 @@ TASK_REGISTRY = {
"anli_r1": anli.ANLIRound1, "anli_r1": anli.ANLIRound1,
"anli_r2": anli.ANLIRound2, "anli_r2": anli.ANLIRound2,
"anli_r3": anli.ANLIRound3, "anli_r3": anli.ANLIRound3,
"ethics_cm": hendrycks_ethics.EthicsCM, "ethics_cm": hendrycks_ethics.EthicsCM,
"ethics_deontology": hendrycks_ethics.EthicsDeontology, "ethics_deontology": hendrycks_ethics.EthicsDeontology,
"ethics_justice": hendrycks_ethics.EthicsJustice, "ethics_justice": hendrycks_ethics.EthicsJustice,
"ethics_utilitarianism_original": hendrycks_ethics.EthicsUtilitarianismOriginal, "ethics_utilitarianism_original": hendrycks_ethics.EthicsUtilitarianismOriginal,
"ethics_utilitarianism": hendrycks_ethics.EthicsUtilitarianism, "ethics_utilitarianism": hendrycks_ethics.EthicsUtilitarianism,
"ethics_virtue": hendrycks_ethics.EthicsVirtue, "ethics_virtue": hendrycks_ethics.EthicsVirtue,
"truthfulqa_mc": truthfulqa.TruthfulQAMultipleChoice,
"truthfulqa_mc": truthfulqa.TruthfulQAMultipleChoice, "truthfulqa_gen": truthfulqa.TruthfulQAGeneration,
"truthfulqa_gen": truthfulqa.TruthfulQAGeneration,
# dialogue # dialogue
"mutual": mutual.MuTual, "mutual": mutual.MuTual,
"mutual_plus": mutual.MuTualPlus, "mutual_plus": mutual.MuTualPlus,
# math # math
"math_algebra": hendrycks_math.MathAlgebra, "math_algebra": hendrycks_math.MathAlgebra,
"math_counting_and_prob": hendrycks_math.MathCountingAndProbability, "math_counting_and_prob": hendrycks_math.MathCountingAndProbability,
...@@ -182,7 +191,6 @@ TASK_REGISTRY = { ...@@ -182,7 +191,6 @@ TASK_REGISTRY = {
"math_precalc": hendrycks_math.MathPrecalculus, "math_precalc": hendrycks_math.MathPrecalculus,
"math_asdiv": asdiv.Asdiv, "math_asdiv": asdiv.Asdiv,
"gsm8k": gsm8k.GradeSchoolMath8K, "gsm8k": gsm8k.GradeSchoolMath8K,
# arithmetic # arithmetic
"arithmetic_2da": arithmetic.Arithmetic2DPlus, "arithmetic_2da": arithmetic.Arithmetic2DPlus,
"arithmetic_2ds": arithmetic.Arithmetic2DMinus, "arithmetic_2ds": arithmetic.Arithmetic2DMinus,
...@@ -196,22 +204,18 @@ TASK_REGISTRY = { ...@@ -196,22 +204,18 @@ TASK_REGISTRY = {
"arithmetic_1dc": arithmetic.Arithmetic1DComposite, "arithmetic_1dc": arithmetic.Arithmetic1DComposite,
# TODO Perhaps make these groups of tasks # TODO Perhaps make these groups of tasks
# e.g. anli, arithmetic, openai_translations, harness_translations # e.g. anli, arithmetic, openai_translations, harness_translations
# hendrycksTest (57 tasks) # hendrycksTest (57 tasks)
**hendrycks_test.create_all_tasks(), **hendrycks_test.create_all_tasks(),
# e.g. wmt14-fr-en # e.g. wmt14-fr-en
**translation.create_tasks_from_benchmarks(gpt3_translation_benchmarks), **translation.create_tasks_from_benchmarks(gpt3_translation_benchmarks),
# chef's selection, mostly wmt20 # chef's selection, mostly wmt20
**translation.create_tasks_from_benchmarks(selected_translation_benchmarks), **translation.create_tasks_from_benchmarks(selected_translation_benchmarks),
# Word Scrambling and Manipulation Tasks # Word Scrambling and Manipulation Tasks
"anagrams1": unscramble.Anagrams1, "anagrams1": unscramble.Anagrams1,
"anagrams2": unscramble.Anagrams2, "anagrams2": unscramble.Anagrams2,
"cycle_letters": unscramble.CycleLetters, "cycle_letters": unscramble.CycleLetters,
"random_insertion": unscramble.RandomInsertion, "random_insertion": unscramble.RandomInsertion,
"reversed_words": unscramble.ReversedWords, "reversed_words": unscramble.ReversedWords,
# Pile # Pile
"pile_arxiv": pile.PileArxiv, "pile_arxiv": pile.PileArxiv,
"pile_books3": pile.PileBooks3, "pile_books3": pile.PileBooks3,
...@@ -235,7 +239,10 @@ TASK_REGISTRY = { ...@@ -235,7 +239,10 @@ TASK_REGISTRY = {
"pile_ubuntu-irc": pile.PileUbuntuIrc, "pile_ubuntu-irc": pile.PileUbuntuIrc,
"pile_wikipedia": pile.PileWikipedia, "pile_wikipedia": pile.PileWikipedia,
"pile_youtubesubtitles": pile.PileYoutubeSubtitles, "pile_youtubesubtitles": pile.PileYoutubeSubtitles,
# BLiMP # BLiMP
"blimp_adjunct_island": blimp.BlimpAdjunctIsland, "blimp_adjunct_island": blimp.BlimpAdjunctIsland,
"blimp_anaphor_gender_agreement": blimp.BlimpAnaphorGenderAgreement, "blimp_anaphor_gender_agreement": blimp.BlimpAnaphorGenderAgreement,
...@@ -304,10 +311,37 @@ TASK_REGISTRY = { ...@@ -304,10 +311,37 @@ TASK_REGISTRY = {
"blimp_wh_vs_that_no_gap_long_distance": blimp.BlimpWhVsThatNoGapLongDistance, "blimp_wh_vs_that_no_gap_long_distance": blimp.BlimpWhVsThatNoGapLongDistance,
"blimp_wh_vs_that_with_gap": blimp.BlimpWhVsThatWithGap, "blimp_wh_vs_that_with_gap": blimp.BlimpWhVsThatWithGap,
"blimp_wh_vs_that_with_gap_long_distance": blimp.BlimpWhVsThatWithGapLongDistance, "blimp_wh_vs_that_with_gap_long_distance": blimp.BlimpWhVsThatWithGapLongDistance,
"toxigen": toxigen.ToxiGen,
"crows_pairs_english": crowspairs.CrowsPairsEnglish,
"crows_pairs_english_race_color": crowspairs.CrowsPairsEnglishRaceColor,
"crows_pairs_english_socioeconomic": crowspairs.CrowsPairsEnglishSocioeconomic,
"crows_pairs_english_gender": crowspairs.CrowsPairsEnglishGender,
"crows_pairs_english_age": crowspairs.CrowsPairsEnglishAge,
"crows_pairs_english_religion": crowspairs.CrowsPairsEnglishReligion,
"crows_pairs_english_disability": crowspairs.CrowsPairsEnglishDisability,
"crows_pairs_english_sexual_orientation": crowspairs.CrowsPairsEnglishSexualOrientation,
"crows_pairs_english_nationality": crowspairs.CrowsPairsEnglishNationality,
"crows_pairs_english_physical_appearance": crowspairs.CrowsPairsEnglishPhysicalAppearance,
"crows_pairs_english_autre": crowspairs.CrowsPairsEnglishAutre,
"crows_pairs_french": crowspairs.CrowsPairsFrench,
"crows_pairs_french_race_color": crowspairs.CrowsPairsFrenchRaceColor,
"crows_pairs_french_socioeconomic": crowspairs.CrowsPairsFrenchSocioeconomic,
"crows_pairs_french_gender": crowspairs.CrowsPairsFrenchGender,
"crows_pairs_french_age": crowspairs.CrowsPairsFrenchAge,
"crows_pairs_french_religion": crowspairs.CrowsPairsFrenchReligion,
"crows_pairs_french_disability": crowspairs.CrowsPairsFrenchDisability,
"crows_pairs_french_sexual_orientation": crowspairs.CrowsPairsFrenchSexualOrientation,
"crows_pairs_french_nationality": crowspairs.CrowsPairsFrenchNationality,
"crows_pairs_french_physical_appearance": crowspairs.CrowsPairsFrenchPhysicalAppearance,
"crows_pairs_french_autre": crowspairs.CrowsPairsFrenchAutre,
# Requires manual download of data. # Requires manual download of data.
# "storycloze_2016": storycloze.StoryCloze2016, # "storycloze_2016": storycloze.StoryCloze2016,
# "storycloze_2018": storycloze.StoryCloze2018, # "storycloze_2018": storycloze.StoryCloze2018,
# "sat": sat.SATAnalogies, # "sat": sat.SATAnalogies,
"klue_sts": klue.STS, "klue_sts": klue.STS,
"klue_ynat": klue.YNAT, "klue_ynat": klue.YNAT,
...@@ -325,6 +359,15 @@ TASK_REGISTRY = { ...@@ -325,6 +359,15 @@ TASK_REGISTRY = {
"kohatespeech":kohatespeech.HateSpeech, "kohatespeech":kohatespeech.HateSpeech,
"kohatespeech_gen_bias":kohatespeech.GenderBias, "kohatespeech_gen_bias":kohatespeech.GenderBias,
"kohatespeech_apeach":kohatespeech.Apeach "kohatespeech_apeach":kohatespeech.Apeach
**xcopa.construct_tasks(),
**bigbench.create_all_tasks(),
**xstorycloze.create_all_tasks(),
**xwinograd.create_all_tasks(),
**pawsx.construct_tasks(),
**xnli.construct_tasks(),
**mgsm.construct_tasks(),
} }
ALL_TASKS = sorted(list(TASK_REGISTRY)) ALL_TASKS = sorted(list(TASK_REGISTRY))
...@@ -333,7 +376,7 @@ ALL_TASKS = sorted(list(TASK_REGISTRY)) ...@@ -333,7 +376,7 @@ ALL_TASKS = sorted(list(TASK_REGISTRY))
def get_task(task_name): def get_task(task_name):
try: try:
return TASK_REGISTRY[task_name] return TASK_REGISTRY[task_name]
except KeyError as e: except KeyError:
print("Available tasks:") print("Available tasks:")
pprint(TASK_REGISTRY) pprint(TASK_REGISTRY)
raise KeyError(f"Missing task {task_name}") raise KeyError(f"Missing task {task_name}")
...@@ -345,17 +388,23 @@ def get_task_name_from_object(task_object): ...@@ -345,17 +388,23 @@ def get_task_name_from_object(task_object):
return name return name
# this gives a mechanism for non-registered tasks to have a custom name anyways when reporting # this gives a mechanism for non-registered tasks to have a custom name anyways when reporting
return task_object.EVAL_HARNESS_NAME if hasattr(task_object, "EVAL_HARNESS_NAME") else type(task_object).__name__ return (
task_object.EVAL_HARNESS_NAME
if hasattr(task_object, "EVAL_HARNESS_NAME")
else type(task_object).__name__
)
def get_task_dict(task_name_list: List[Union[str, lm_eval.base.Task]]): def get_task_dict(task_name_list: List[Union[str, lm_eval.base.Task]]):
task_name_dict = { task_name_dict = {
task_name: get_task(task_name)() task_name: get_task(task_name)()
for task_name in task_name_list if isinstance(task_name, str) for task_name in task_name_list
if isinstance(task_name, str)
} }
task_name_from_object_dict = { task_name_from_object_dict = {
get_task_name_from_object(task_object): task_object get_task_name_from_object(task_object): task_object
for task_object in task_name_list if not isinstance(task_object, str) for task_object in task_name_list
if not isinstance(task_object, str)
} }
assert set(task_name_dict.keys()).isdisjoint(set(task_name_from_object_dict.keys())) assert set(task_name_dict.keys()).isdisjoint(set(task_name_from_object_dict.keys()))
return {**task_name_dict, **task_name_from_object_dict} return {**task_name_dict, **task_name_from_object_dict}
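# Usage sketch (hypothetical, not part of the original file): task names and
# already-instantiated Task objects can be mixed in the same call, e.g.
#   task_dict = get_task_dict(["hellaswag", "arc_easy", "lambada_openai"])
# which returns a {task_name: Task instance} mapping for the evaluator to consume.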
...@@ -61,36 +61,47 @@ class ANLIBase(Task):
    def doc_to_text(self, doc):
        # OA does this a bit weirdly: they prepend "anli 1: anli 1: " to the beginning
        # of the prompt (yes, repeating it!). also, " True, False, or Neither?" is directly
        # appended onto the question, with no "Answer:" or even a newline. Do we *really*
        # want to do it exactly as OA did?
-       return doc['premise'] + '\nQuestion: ' + doc['hypothesis'] + ' True, False, or Neither?\nAnswer:'
+       return (
+           doc["premise"]
+           + "\nQuestion: "
+           + doc["hypothesis"]
+           + " True, False, or Neither?\nAnswer:"
+       )

+   def should_decontaminate(self):
+       return True
+
+   def doc_to_decontamination_query(self, doc):
+       return doc["premise"]

    def doc_to_target(self, doc):
        # True = entailment
        # False = contradiction
        # Neither = neutral
-       return " " + ["True", "Neither", "False"][doc['label']]
+       return " " + ["True", "Neither", "False"][doc["label"]]

    def construct_requests(self, doc, ctx):
-       """ Uses RequestFactory to construct Requests and returns an iterable of
+       """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        """
        ll_true, _ = rf.loglikelihood(ctx, " True")
        ll_neither, _ = rf.loglikelihood(ctx, " Neither")
        ll_false, _ = rf.loglikelihood(ctx, " False")
        return ll_true, ll_neither, ll_false

    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document

        :param doc:
...@@ -100,29 +111,23 @@ class ANLIBase(Task):
        """
        gold = doc["label"]
        pred = np.argmax(results)
-       return {
-           "acc": pred == gold
-       }
+       return {"acc": pred == gold}

    def aggregation(self):
        """
        :returns: {str: [float] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metrics
        """
-       return {
-           "acc": mean
-       }
+       return {"acc": mean}

    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
-       return {
-           "acc": True
-       }
+       return {"acc": True}


class ANLIRound1(ANLIBase):
...

...@@ -67,6 +67,12 @@ class ARCEasy(MultipleChoiceTask):
    def doc_to_text(self, doc):
        return doc["query"]

+   def should_decontaminate(self):
+       return True
+
+   def doc_to_decontamination_query(self, doc):
+       return doc["query"]

class ARCChallenge(ARCEasy):
    DATASET_PATH = "ai2_arc"

...
...@@ -7,8 +7,6 @@ problem in natural language.
Homepage: https://github.com/openai/gpt-3/tree/master/data
"""
-import inspect
-import lm_eval.datasets.arithmetic.arithmetic
from lm_eval.base import Task, rf
from lm_eval.metrics import mean

...@@ -30,7 +28,7 @@ _CITATION = """
class Arithmetic(Task):
    VERSION = 0
-   DATASET_PATH = inspect.getfile(lm_eval.datasets.arithmetic.arithmetic)
+   DATASET_PATH = "EleutherAI/arithmetic"

    def has_training_docs(self):
        return False

...@@ -49,10 +47,16 @@ class Arithmetic(Task):
    def test_docs(self):
        return NotImplemented

    def doc_to_text(self, doc):
        return doc["context"]

+   def should_decontaminate(self):
+       return True
+
+   def doc_to_decontamination_query(self, doc):
+       return doc["context"]

    def doc_to_target(self, doc):
        return doc["completion"]

...@@ -61,10 +65,8 @@ class Arithmetic(Task):
        return is_prediction

    def process_results(self, doc, results):
-       is_prediction, = results
-       return {
-           "acc": is_prediction
-       }
+       (is_prediction,) = results
+       return {"acc": is_prediction}

    def aggregation(self):
        return {

...@@ -72,9 +74,7 @@ class Arithmetic(Task):
        }

    def higher_is_better(self):
-       return {
-           "acc": True
-       }
+       return {"acc": True}


class Arithmetic2DPlus(Arithmetic):
...

...@@ -54,42 +54,41 @@ class Asdiv(Task):
    def test_docs(self):
        raise NotImplementedError("This dataset has no test docs")

-   def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, description=None):
+   def fewshot_context(
+       self, doc, num_fewshot, provide_description=None, rnd=None, description=None
+   ):
        assert num_fewshot == 0, "ASDiv is intended only for the zero-shot setting."
        return super().fewshot_context(
-           doc=doc,
-           num_fewshot=num_fewshot,
-           rnd=rnd,
-           description=description
+           doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
        )

    def doc_to_text(self, doc):
        # TODO: add solution-type
-       return doc['body'] + '\n' + 'Question:' + doc['question'] + '\n' + 'Answer:'
+       return doc["body"] + "\n" + "Question:" + doc["question"] + "\n" + "Answer:"

+   def should_decontaminate(self):
+       return True
+
+   def doc_to_decontamination_query(self, doc):
+       return doc["body"] + " " + doc["question"]

    def doc_to_target(self, doc):
        # TODO: add formula
-       answer = doc['answer'].split(' (')[0]
+       answer = doc["answer"].split(" (")[0]
        return " " + answer

    def construct_requests(self, doc, ctx):
        ll, is_greedy = rf.loglikelihood(ctx, self.doc_to_target(doc))
        return ll, is_greedy

    def process_results(self, doc, results):
        ll, is_greedy = results
-       return {
-           'acc': int(is_greedy)
-       }
+       return {"acc": int(is_greedy)}

    def aggregation(self):
-       return {
-           'acc': mean
-       }
+       return {"acc": mean}

    def higher_is_better(self):
-       return {
-           'acc': True
-       }
+       return {"acc": True}
"""
Tasks missing from BIG-bench-hard:
programmatic tasks (defined in code rather than JSON, so not loadable here) - boolean_expressions, web of lies, multistep_arithmetic
"""
import os
import json
import hashlib
import functools
import numpy as np
import re
import importlib.resources
from lm_eval.base import rf, Task
from lm_eval.metrics import mean
_CITATION = """
@misc{srivastava2022imitation,
title={Beyond the Imitation Game: Quantifying and extrapolating the capabilities of language models},
author={Aarohi Srivastava and Abhinav Rastogi and Abhishek Rao and Abu Awal Md Shoeb and Abubakar Abid and Adam Fisch and Adam R. Brown and Adam Santoro and Aditya Gupta and Adrià Garriga-Alonso and Agnieszka Kluska and Aitor Lewkowycz and Akshat Agarwal and Alethea Power and Alex Ray and Alex Warstadt and Alexander W. Kocurek and Ali Safaya and Ali Tazarv and Alice Xiang and Alicia Parrish and Allen Nie and Aman Hussain and Amanda Askell and Amanda Dsouza and Ambrose Slone and Ameet Rahane and Anantharaman S. Iyer and Anders Andreassen and Andrea Madotto and Andrea Santilli and Andreas Stuhlmüller and Andrew Dai and Andrew La and Andrew Lampinen and Andy Zou and Angela Jiang and Angelica Chen and Anh Vuong and Animesh Gupta and Anna Gottardi and Antonio Norelli and Anu Venkatesh and Arash Gholamidavoodi and Arfa Tabassum and Arul Menezes and Arun Kirubarajan and Asher Mullokandov and Ashish Sabharwal and Austin Herrick and Avia Efrat and Aykut Erdem and Ayla Karakaş and B. Ryan Roberts and Bao Sheng Loe and Barret Zoph and Bartłomiej Bojanowski and Batuhan Özyurt and Behnam Hedayatnia and Behnam Neyshabur and Benjamin Inden and Benno Stein and Berk Ekmekci and Bill Yuchen Lin and Blake Howald and Cameron Diao and Cameron Dour and Catherine Stinson and Cedrick Argueta and César Ferri Ramírez and Chandan Singh and Charles Rathkopf and Chenlin Meng and Chitta Baral and Chiyu Wu and Chris Callison-Burch and Chris Waites and Christian Voigt and Christopher D. Manning and Christopher Potts and Cindy Ramirez and Clara E. Rivera and Clemencia Siro and Colin Raffel and Courtney Ashcraft and Cristina Garbacea and Damien Sileo and Dan Garrette and Dan Hendrycks and Dan Kilman and Dan Roth and Daniel Freeman and Daniel Khashabi and Daniel Levy and Daniel Moseguí González and Danielle Perszyk and Danny Hernandez and Danqi Chen and Daphne Ippolito and Dar Gilboa and David Dohan and David Drakard and David Jurgens and Debajyoti Datta and Deep Ganguli and Denis Emelin and Denis Kleyko and Deniz Yuret and Derek Chen and Derek Tam and Dieuwke Hupkes and Diganta Misra and Dilyar Buzan and Dimitri Coelho Mollo and Diyi Yang and Dong-Ho Lee and Ekaterina Shutova and Ekin Dogus Cubuk and Elad Segal and Eleanor Hagerman and Elizabeth Barnes and Elizabeth Donoway and Ellie Pavlick and Emanuele Rodola and Emma Lam and Eric Chu and Eric Tang and Erkut Erdem and Ernie Chang and Ethan A. Chi and Ethan Dyer and Ethan Jerzak and Ethan Kim and Eunice Engefu Manyasi and Evgenii Zheltonozhskii and Fanyue Xia and Fatemeh Siar and Fernando Martínez-Plumed and Francesca Happé and Francois Chollet and Frieda Rong and Gaurav Mishra and Genta Indra Winata and Gerard de Melo and Germán Kruszewski and Giambattista Parascandolo and Giorgio Mariani and Gloria Wang and Gonzalo Jaimovitch-López and Gregor Betz and Guy Gur-Ari and Hana Galijasevic and Hannah Kim and Hannah Rashkin and Hannaneh Hajishirzi and Harsh Mehta and Hayden Bogar and Henry Shevlin and Hinrich Schütze and Hiromu Yakura and Hongming Zhang and Hugh Mee Wong and Ian Ng and Isaac Noble and Jaap Jumelet and Jack Geissinger and Jackson Kernion and Jacob Hilton and Jaehoon Lee and Jaime Fernández Fisac and James B. 
Simon and James Koppel and James Zheng and James Zou and Jan Kocoń and Jana Thompson and Jared Kaplan and Jarema Radom and Jascha Sohl-Dickstein and Jason Phang and Jason Wei and Jason Yosinski and Jekaterina Novikova and Jelle Bosscher and Jennifer Marsh and Jeremy Kim and Jeroen Taal and Jesse Engel and Jesujoba Alabi and Jiacheng Xu and Jiaming Song and Jillian Tang and Joan Waweru and John Burden and John Miller and John U. Balis and Jonathan Berant and Jörg Frohberg and Jos Rozen and Jose Hernandez-Orallo and Joseph Boudeman and Joseph Jones and Joshua B. Tenenbaum and Joshua S. Rule and Joyce Chua and Kamil Kanclerz and Karen Livescu and Karl Krauth and Karthik Gopalakrishnan and Katerina Ignatyeva and Katja Markert and Kaustubh D. Dhole and Kevin Gimpel and Kevin Omondi and Kory Mathewson and Kristen Chiafullo and Ksenia Shkaruta and Kumar Shridhar and Kyle McDonell and Kyle Richardson and Laria Reynolds and Leo Gao and Li Zhang and Liam Dugan and Lianhui Qin and Lidia Contreras-Ochando and Louis-Philippe Morency and Luca Moschella and Lucas Lam and Lucy Noble and Ludwig Schmidt and Luheng He and Luis Oliveros Colón and Luke Metz and Lütfi Kerem Şenel and Maarten Bosma and Maarten Sap and Maartje ter Hoeve and Maheen Farooqi and Manaal Faruqui and Mantas Mazeika and Marco Baturan and Marco Marelli and Marco Maru and Maria Jose Ramírez Quintana and Marie Tolkiehn and Mario Giulianelli and Martha Lewis and Martin Potthast and Matthew L. Leavitt and Matthias Hagen and Mátyás Schubert and Medina Orduna Baitemirova and Melody Arnaud and Melvin McElrath and Michael A. Yee and Michael Cohen and Michael Gu and Michael Ivanitskiy and Michael Starritt and Michael Strube and Michał Swędrowski and Michele Bevilacqua and Michihiro Yasunaga and Mihir Kale and Mike Cain and Mimee Xu and Mirac Suzgun and Mo Tiwari and Mohit Bansal and Moin Aminnaseri and Mor Geva and Mozhdeh Gheini and Mukund Varma T and Nanyun Peng and Nathan Chi and Nayeon Lee and Neta Gur-Ari Krakover and Nicholas Cameron and Nicholas Roberts and Nick Doiron and Nikita Nangia and Niklas Deckers and Niklas Muennighoff and Nitish Shirish Keskar and Niveditha S. Iyer and Noah Constant and Noah Fiedel and Nuan Wen and Oliver Zhang and Omar Agha and Omar Elbaghdadi and Omer Levy and Owain Evans and Pablo Antonio Moreno Casares and Parth Doshi and Pascale Fung and Paul Pu Liang and Paul Vicol and Pegah Alipoormolabashi and Peiyuan Liao and Percy Liang and Peter Chang and Peter Eckersley and Phu Mon Htut and Pinyu Hwang and Piotr Miłkowski and Piyush Patil and Pouya Pezeshkpour and Priti Oli and Qiaozhu Mei and Qing Lyu and Qinlang Chen and Rabin Banjade and Rachel Etta Rudolph and Raefer Gabriel and Rahel Habacker and Ramón Risco Delgado and Raphaël Millière and Rhythm Garg and Richard Barnes and Rif A. Saurous and Riku Arakawa and Robbe Raymaekers and Robert Frank and Rohan Sikand and Roman Novak and Roman Sitelew and Ronan LeBras and Rosanne Liu and Rowan Jacobs and Rui Zhang and Ruslan Salakhutdinov and Ryan Chi and Ryan Lee and Ryan Stovall and Ryan Teehan and Rylan Yang and Sahib Singh and Saif M. Mohammad and Sajant Anand and Sam Dillavou and Sam Shleifer and Sam Wiseman and Samuel Gruetter and Samuel R. Bowman and Samuel S. Schoenholz and Sanghyun Han and Sanjeev Kwatra and Sarah A. 
Rous and Sarik Ghazarian and Sayan Ghosh and Sean Casey and Sebastian Bischoff and Sebastian Gehrmann and Sebastian Schuster and Sepideh Sadeghi and Shadi Hamdan and Sharon Zhou and Shashank Srivastava and Sherry Shi and Shikhar Singh and Shima Asaadi and Shixiang Shane Gu and Shubh Pachchigar and Shubham Toshniwal and Shyam Upadhyay and Shyamolima and Debnath and Siamak Shakeri and Simon Thormeyer and Simone Melzi and Siva Reddy and Sneha Priscilla Makini and Soo-Hwan Lee and Spencer Torene and Sriharsha Hatwar and Stanislas Dehaene and Stefan Divic and Stefano Ermon and Stella Biderman and Stephanie Lin and Stephen Prasad and Steven T. Piantadosi and Stuart M. Shieber and Summer Misherghi and Svetlana Kiritchenko and Swaroop Mishra and Tal Linzen and Tal Schuster and Tao Li and Tao Yu and Tariq Ali and Tatsu Hashimoto and Te-Lin Wu and Théo Desbordes and Theodore Rothschild and Thomas Phan and Tianle Wang and Tiberius Nkinyili and Timo Schick and Timofei Kornev and Timothy Telleen-Lawton and Titus Tunduny and Tobias Gerstenberg and Trenton Chang and Trishala Neeraj and Tushar Khot and Tyler Shultz and Uri Shaham and Vedant Misra and Vera Demberg and Victoria Nyamai and Vikas Raunak and Vinay Ramasesh and Vinay Uday Prabhu and Vishakh Padmakumar and Vivek Srikumar and William Fedus and William Saunders and William Zhang and Wout Vossen and Xiang Ren and Xiaoyu Tong and Xinran Zhao and Xinyi Wu and Xudong Shen and Yadollah Yaghoobzadeh and Yair Lakretz and Yangqiu Song and Yasaman Bahri and Yejin Choi and Yichi Yang and Yiding Hao and Yifu Chen and Yonatan Belinkov and Yu Hou and Yufang Hou and Yuntao Bai and Zachary Seid and Zhuoye Zhao and Zijian Wang and Zijie J. Wang and Zirui Wang and Ziyi Wu},
year={2022},
eprint={2206.04615},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
"""
_DEFAULT_REGEX = r"[^\.\?\!\;\n]+"
class BigBenchJsonTask(Task):
VERSION = 0
def __init__(self, json_path):
self._random_seed = 42
with open(json_path) as file:
self._task_json = json.load(file)
self._has_multi_choice = "multiple_choice_grade" in self._task_json["metrics"]
self._has_generative = "exact_str_match" in self._task_json["metrics"]
self.output_regex = self._task_json.get("output_regex", None)
self.stop_string = self._task_json.get("stop_string", None)
if self.output_regex is None and self.stop_string is None:
self.output_regex = _DEFAULT_REGEX
# NOTE: this differs from the default max_length of 30 used when evaluating HF models in the BIG-bench codebase
self.max_length = 128
def has_training_docs(self):
return False
def has_validation_docs(self):
return False
def has_test_docs(self):
return True
def test_docs(self):
return _get_unique_examples(self._task_json["examples"])
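# Prompt assembly: prepend `example_input_prefix` (default "\nQ: ") to the
# example input, optionally append the answer choices in a seed-fixed shuffled
# order behind `choice_prefix`, and finish with `example_output_prefix`
# (default "\nA: ").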
def doc_to_text(self, doc):
example_input_prefix = self._task_json.get("example_input_prefix", "\nQ: ")
res = f"{example_input_prefix}{doc['input']}"
rng = np.random.RandomState(seed=self._random_seed)
choice_prefix = self._task_json.get("choice_prefix", "\n choice: ")
append_choices = self._task_json.get("append_choices_to_input", True)
if "target_scores" in doc and append_choices:
choice_dict = doc["target_scores"]
permuted_choices = rng.permutation(sorted(list(choice_dict.keys())))
res = f"{res}{choice_prefix}{choice_prefix.join(permuted_choices)}"
example_output_prefix = self._task_json.get("example_output_prefix", "\nA: ")
res = f"{res}{example_output_prefix}"
return res
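# For multiple-choice docs the gold target is the key with the highest value
# in `target_scores`.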
def doc_to_target(self, doc):
return max(doc["target_scores"].items(), key=lambda x: x[1])[0]
def _doc_to_queries(self, doc):
if "target_scores" in doc:
return list(doc["target_scores"].keys())
return doc["target"] if isinstance(doc["target"], list) else [doc["target"]]
def construct_requests(self, doc, ctx):
requests = []
if self._has_multi_choice:
queries = self._doc_to_queries(doc)
requests += [
rf.loglikelihood(ctx, continuation)[0] for continuation in queries
]
if self._has_generative:
requests.append(
rf.greedy_until(ctx, {"until": [], "max_length": self.max_length})
)
return requests
def process_results(self, doc, results):
res = {}
for metric in self._task_json["metrics"]:
if metric == "multiple_choice_grade":
likelihoods = results[:-1] if self._has_generative else results
queries = self._doc_to_queries(doc)
highest_score_index = _argmax(likelihoods)
highest_score_key = queries[highest_score_index]
res["multiple_choice_grade"] = doc["target_scores"][highest_score_key]
elif metric == "exact_str_match":
postprocessed = _postprocess_output(
results[-1],
max_length=self.max_length,
stop_string=self.stop_string,
output_regex=self.output_regex,
)
res["exact_str_match"] = int(postprocessed == doc["target"])
else:
raise NotImplementedError(f"Metric {metric} isn't implemented")
return res
def aggregation(self):
return {
"multiple_choice_grade": mean,
"exact_str_match": mean,
}
def higher_is_better(self):
return {
"multiple_choice_grade": True,
"exact_str_match": True,
}
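# Few-shot contexts are precomputed once per shot count (hence the lru_cache)
# with a fixed seed: for every test doc, `shots` other docs are sampled without
# replacement, rendered and joined with `few_shot_example_separator`, and the
# result is keyed by the doc's JSON serialization for lookup in fewshot_context.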
@functools.lru_cache()
def _doc_to_few_shot_context(self, shots):
rng = np.random.RandomState(seed=self._random_seed)
res = {}
samples = self.test_docs()
separator = self._task_json.get("few_shot_example_separator", "\n")
for sample in rng.choice(samples, len(samples), replace=False):
valid_samples = [x for x in samples if x != sample]
shot_examples = list(rng.choice(valid_samples, shots, replace=False))
if self._has_multi_choice:
context = separator.join(
[
self.doc_to_text(example)
+ rng.choice(_get_valid_answers(example["target_scores"]))
for example in shot_examples
]
)
else:
context = separator.join(
[
self.doc_to_text(example) + example["target"]
for example in shot_examples
]
)
res[json.dumps(sample)] = context + separator + self.doc_to_text(sample)
return res
def fewshot_context(self, doc, num_fewshot, **kwargs):
if num_fewshot == 0:
res = self.doc_to_text(doc)
else:
res = self._doc_to_few_shot_context(shots=num_fewshot)[json.dumps(doc)]
res = f"{self._task_json.get('task_prefix', '')}{res}"
return res
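# Helper: every answer key that attains the maximum score counts as a valid
# few-shot target, e.g. _get_valid_answers({"yes": 1, "no": 0, "maybe": 1})
# returns ["yes", "maybe"].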
def _get_valid_answers(scores):
max_value = max(scores.values())
return [key for key, value in scores.items() if value == max_value]
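# De-duplicates examples by their JSON serialization while preserving order.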
def _get_unique_examples(examples):
seen_examples, res = set(), []
for example in examples:
example_string = json.dumps(example)
if example_string not in seen_examples:
res.append(example)
seen_examples.add(example_string)
return res
def _argmax(array):
"""argmax with deterministic pseudorandom tie breaking."""
max_indices = np.arange(len(array))[array == np.max(array)]
idx = int(hashlib.sha256(np.asarray(array).tobytes()).hexdigest(), 16) % len(
max_indices
)
return max_indices[idx]
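# Post-processing of generated text, applied in order: truncate to `max_length`
# characters, drop everything after the first occurrence of `stop_string`
# (the stop string itself is kept), then keep the first `output_regex` match
# (empty string when nothing matches). Lists are processed element-wise.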
def _postprocess_output(text, max_length, stop_string, output_regex):
if isinstance(text, list):
return [
_postprocess_output(mo, max_length, stop_string, output_regex)
for mo in text
]
# Ensure it is a string (will convert from bytes, ... as needed)
if not isinstance(text, str):
text = str(text, "utf-8")
# truncate at max_length
if max_length:
text = text[:max_length]
# Remove all text after any stop_string
if stop_string:
index = text.find(stop_string)
if index > 0:
text = text[: index + len(stop_string)]
# extract substring matching regex (empty string for no match)
if output_regex:
_text = text
text = next(iter(re.findall(output_regex, text)), "")
assert (
not type(text) is tuple
), f'Regex {output_regex} returned multiple matching groups when applied to string {_text}. Try using non-capturing groups, by starting regex groups with ?: (e.g. "(stuff)" -> "(?:stuff)").'
return text
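# Factory: binds a concrete JSON file to a zero-argument Task subclass so the
# harness can instantiate BIG-bench tasks the same way as the built-in ones.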
def create_task_from_path(json_path):
class WrappedTask(BigBenchJsonTask):
def __init__(self):
super().__init__(json_path)
return WrappedTask
def create_all_tasks():
resources_dir = importlib.resources.files("lm_eval.datasets") / "bigbench_resources"
supported_tasks = [os.path.splitext(x)[0] for x in os.listdir(resources_dir)]
res = {}
for task_name in supported_tasks:
task_path = os.path.join(resources_dir, f"{task_name}.json")
res[f"bigbench_{task_name}"] = create_task_from_path(task_path)
return res
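# Usage sketch (illustrative; the registration point and module path are
# assumptions, not part of this file): the generated classes would typically be
# merged into the harness's name -> Task-class mapping, e.g.
#
#     from lm_eval.tasks import bigbench  # assuming this module is lm_eval/tasks/bigbench.py
#     TASK_REGISTRY.update(bigbench.create_all_tasks())
#
# where TASK_REGISTRY stands for whatever registry the caller uses to resolve
# task names (every key produced here is prefixed with "bigbench_").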