Commit 6597c347 authored by researcher2, committed by researcher2

Implement decontamination of evals against training-set 13-grams.

parent 67e2bf8b
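
The change threads decontamination through the whole harness: tasks declare whether and how they can be decontaminated, the evaluator collects per-document queries and checks them against precomputed training-set n-grams, and every metric gains a `_decontaminate` twin computed over only the non-overlapping documents. As a rough usage sketch (not part of this commit; the model name and n-grams path are placeholders):

    from lm_eval import evaluator

    results = evaluator.simple_evaluate(
        "gpt2", "", ["sciq", "logiqa"],
        decontaminate=True,
        ngrams_path="/data/pile_13grams",  # placeholder: directory of precomputed training-set n-grams
        ngrams_n_size=13,
    )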
@@ -128,6 +128,10 @@ class Task(abc.ABC):
         """Downloads the task dataset if necessary"""
         pass
 
+    def should_decontaminate(self):
+        """Whether this task supports decontamination against model training set."""
+        return False
+
     @abc.abstractmethod
     def has_training_docs(self):
         """Whether the task has a training set"""
@@ -170,6 +174,10 @@ class Task(abc.ABC):
         return rnd.sample(self._training_docs, k)
 
+    def doc_to_decontamination_query(self, doc):
+        print("Override doc_to_decontamination_query with document specific decontamination query.")
+        assert False
+
     @abc.abstractmethod
     def doc_to_text(self, doc):
         pass
@@ -292,6 +300,10 @@ class MultipleChoiceTask(Task):
 
 class PerplexityTask(Task, abc.ABC):
 
+    def should_decontaminate(self):
+        """Whether this task supports decontamination against model training set."""
+        return True
+
     def has_training_docs(self):
         return False
@@ -314,6 +326,9 @@ class PerplexityTask(Task, abc.ABC):
         "bits_per_byte": False,
     }
 
+    def doc_to_decontamination_query(self, doc):
+        return doc
+
     def doc_to_text(self, doc):
         return ""
...
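
With the two hooks above, any task can opt in to decontamination. A minimal sketch of a hypothetical task, mirroring the LogiQA/SciQ overrides further down (the `"passage"` field is assumed for illustration):

    class MyTask(Task):  # hypothetical example, not part of this commit
        def should_decontaminate(self):
            return True

        def doc_to_decontamination_query(self, doc):
            # Return the raw text to check against training-set n-grams.
            return doc["passage"]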
@@ -5,13 +5,16 @@ import lm_eval.metrics
 import lm_eval.models
 import lm_eval.tasks
 import lm_eval.base
+from scripts.clean_training_data.contamination import get_train_overlap
 import numpy as np
 
-def simple_evaluate(model, model_args, task_names, num_fewshot=0, batch_size=None, device=None, no_cache=False, limit=None, bootstrap_iters=100000):
+def simple_evaluate(model, model_args, task_names, num_fewshot=0, batch_size=None, device=None,
+                    no_cache=False, limit=None, bootstrap_iters=100000, decontaminate=False,
+                    ngrams_path=None, ngrams_n_size=None):
     random.seed(1234)
     np.random.seed(1234)
 
-    lm = lm_eval.models.get_model(model).create_from_arg_string(model_args, {
+    lm = lm_eval.models.MODEL_REGISTRY[model].create_from_arg_string(model_args, {
         'batch_size': batch_size, 'device': device
     })
@@ -19,7 +22,8 @@ def simple_evaluate(model, model_args, task_names, num_fewshot=0, batch_size=Non
         lm = lm_eval.base.CachingLM(lm, 'lm_cache/' + model + '_' + model_args.replace('=', '-').replace(',', '_').replace('/', '-') + '.db')
 
     task_dict = lm_eval.tasks.get_task_dict(task_names)
-    results = evaluate(lm, task_dict, False, num_fewshot, limit)
+    results = evaluate(lm, task_dict, False, num_fewshot, limit, bootstrap_iters=bootstrap_iters,
+                       decontaminate=decontaminate, ngrams_path=ngrams_path, ngrams_n_size=ngrams_n_size)
 
     # add info about the model and few shot config
     results["config"] = {
@@ -35,10 +39,15 @@ def simple_evaluate(model, model_args, task_names, num_fewshot=0, batch_size=Non
     return results
 
+
+decontaminate_suffix = "_decontaminate"
+
+
-def evaluate(lm, task_dict, provide_description, num_fewshot, limit, bootstrap_iters=100000):
+def evaluate(lm, task_dict, provide_description, num_fewshot, limit, bootstrap_iters=100000,
+             decontaminate=False, ngrams_path=None, ngrams_n_size=None):
     assert not provide_description  # not implemented. todo: implement proper description-providing system
 
+    if decontaminate:
+        assert ngrams_path and ngrams_n_size
+
     # TODO: completely refactor this entire function to not be a huge mess, ideally breaking it down into smaller pieces
     task_dict_items = [(name, task) for name, task in task_dict.items() if(task.has_validation_docs() or task.has_test_docs())]
@@ -49,6 +58,8 @@ def evaluate(lm, task_dict, provide_description, num_fewshot, limit, bootstrap_i
     requests = collections.defaultdict(list)
     requests_origin = collections.defaultdict(list)
 
+    overlaps = collections.defaultdict(list)  # {task_name: contaminated_docs}
+
     # if we ever run into issues where the eval tasks don't fit in memory and we can't afford a machine with bigger memory,
     # we can always modify this plumbing to support that, but i didn't want to include it just yet because overengineering is bad
     # (or we could make it write the requests to disk and then read them back out again - probably using an sqlite db because of all the moving parts we have
@@ -57,6 +68,8 @@ def evaluate(lm, task_dict, provide_description, num_fewshot, limit, bootstrap_i
     docs = {}
 
+    docs_for_decontamination = collections.defaultdict(list)
+
     # get lists of each type of request
     for task_name, task in task_dict_items:
         versions[task_name] = task.VERSION
@@ -64,7 +77,9 @@ def evaluate(lm, task_dict, provide_description, num_fewshot, limit, bootstrap_i
         # TODO: the test-fallback-to-val system isn't final, we should revisit it at some point
         if task.has_test_docs():
             task_doc_func = task.test_docs
+            task_set = "test"  # Required for caching in the decontamination
         elif task.has_validation_docs():
+            task_set = "val"  # Required for caching in the decontamination
             task_doc_func = task.validation_docs
 
         # deterministically shuffle docs and chop off the first `limit` because sometimes docs are in some kind of order
@@ -74,6 +89,10 @@ def evaluate(lm, task_dict, provide_description, num_fewshot, limit, bootstrap_i
         rnd.shuffle(task_docs)
 
         for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):
+            if decontaminate and task.should_decontaminate():
+                docs_for_decontamination[(task_name, task_set)].append(task.doc_to_decontamination_query(doc))
+
             docs[(task_name, doc_id)] = doc
             ctx = task.fewshot_context(
@@ -91,6 +110,11 @@ def evaluate(lm, task_dict, provide_description, num_fewshot, limit, bootstrap_i
                 # doc_id: unique id that we can get back to a doc using `docs`
                 requests_origin[req.type].append((i, task_name, doc, doc_id))
 
+    # Compare all tasks/sets at once to ensure a single training set scan
+    if decontaminate:
+        print("Finding train/test overlap, please wait...")
+        overlaps = get_train_overlap(docs_for_decontamination, ngrams_path, ngrams_n_size, limit)
+
     # all responses for each (task, doc)
     process_res_queue = collections.defaultdict(list)
@@ -122,14 +146,22 @@ def evaluate(lm, task_dict, provide_description, num_fewshot, limit, bootstrap_i
         for metric, value in metrics.items():
             vals[(task_name, metric)].append(value)
 
+            # Re-use the evaluation for the decontaminated set by just ignoring the overlaps
+            if decontaminate and task_name in overlaps:
+                if doc_id not in overlaps[task_name]:
+                    vals[(task_name, metric + decontaminate_suffix)].append(value)
+
     # aggregate results
     for (task_name, metric), items in vals.items():
         task = task_dict[task_name]
-        results[task_name][metric] = task.aggregation()[metric](items)
+        real_metric = metric  # key when looking up the metric with task.aggregation
+        if metric.endswith(decontaminate_suffix):
+            real_metric = metric.replace(decontaminate_suffix, "")  # decontaminated still uses the same metric
+        results[task_name][metric] = task.aggregation()[real_metric](items)
 
         # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
         # so we run them for fewer iterations. still looking for a cleaner way to do this
-        stderr = lm_eval.metrics.stderr_for_metric(task.aggregation()[metric], bootstrap_iters=min(bootstrap_iters, 1000) if metric in ["bleu", "chrf", "ter"] else bootstrap_iters)
+        stderr = lm_eval.metrics.stderr_for_metric(task.aggregation()[real_metric], bootstrap_iters=min(bootstrap_iters, 1000) if metric in ["bleu", "chrf", "ter"] else bootstrap_iters)
         if stderr is not None:
             results[task_name][metric + "_stderr"] = stderr(items)
...
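
For reporting, `get_train_overlap` is expected to return a mapping from task name to the contaminated doc ids (per the `{task_name: contaminated_docs}` comment above), and each affected metric then appears twice in the results: once as usual, and once under a `_decontaminate`-suffixed key aggregated over only the clean documents. A sketch of the resulting shape, with invented values for illustration:

    # hypothetical output shape; numbers are made up
    results["sciq"] = {
        "acc": 0.93,                # all documents
        "acc_decontaminate": 0.91,  # overlapping documents excluded
        "acc_stderr": 0.01,
    }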
@@ -19,8 +19,13 @@ class GPT2LM(LM):
         assert isinstance(batch_size, int)
 
         if device:
+            if device not in ["cuda", "cpu"]:
+                device = int(device)
             self.device = torch.device(device)
+            print(f"Using device '{device}'")
         else:
+            print("Device not specified")
+            print(f"Cuda Available? {torch.cuda.is_available()}")
             self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
 
         # TODO: update this to be less of a hack once subfolder is fixed in HF
...
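
The device handling now also accepts a bare GPU index: anything other than "cuda" or "cpu" is cast to int before being handed to torch.device. For example (sketch, not part of this commit):

    import torch

    torch.device("cuda")    # explicit CUDA
    torch.device("cpu")     # explicit CPU
    torch.device(int("0"))  # "--device 0" becomes cuda:0 via the integer index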
@@ -50,6 +50,7 @@ class LogiQA(MultipleChoiceTask):
             return prompt
 
         choices = ['a', 'b', 'c', 'd']
         return {
+            "passage": doc["passage"],  # Used for decontamination
             "query": format_example(doc, choices),
             "choices": doc["options"],
             "gold": choices.index(doc["answerKey"])
@@ -86,3 +87,9 @@ class LogiQA(MultipleChoiceTask):
 
     def doc_to_text(self, doc):
         return doc["query"]
+
+    def should_decontaminate(self):
+        return True
+
+    def doc_to_decontamination_query(self, doc):
+        return doc["passage"]
@@ -64,3 +64,9 @@ class SciQ(MultipleChoiceTask):
 
     def doc_to_text(self, doc):
         return "{}\nQuestion: {}\nAnswer:".format(doc["source"], doc["query"]).strip()
+
+    def should_decontaminate(self):
+        return True
+
+    def doc_to_decontamination_query(self, doc):
+        return doc["source"] + " " + doc["query"]
\ No newline at end of file
 import argparse
 import json
+import numpy as np
+import random
 import logging
+import fnmatch
 
-from lm_eval import models, tasks, evaluator, base
+from lm_eval import tasks, evaluator
 
 logging.getLogger("openai").setLevel(logging.WARNING)
 
+
+class MultiChoice:
+    def __init__(self, choices):
+        self.choices = choices
+
+    # Simple wildcard support (linux filename patterns)
+    def __contains__(self, values):
+        for value in values.split(","):
+            if len(fnmatch.filter(self.choices, value)) == 0:
+                return False
+        return True
+
+    def __iter__(self):
+        for choice in self.choices:
+            yield choice
+
+
+# Get task base classes for filtering
+task_types = list(set([task.__bases__[0].__name__ for task in tasks.TASK_REGISTRY.values()]))
+
+
 def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--model', required=True)
     parser.add_argument('--model_args', default="")
-    parser.add_argument('--tasks', default="all_tasks")
+    parser.add_argument('--tasks', default=None, choices=MultiChoice(tasks.ALL_TASKS))
+    parser.add_argument('--task_type', default=None, choices=MultiChoice(task_types))
+    parser.add_argument('--exclude_tasks', default=None, choices=MultiChoice(tasks.ALL_TASKS))
     parser.add_argument('--provide_description', action="store_true")
     parser.add_argument('--num_fewshot', type=int, default=0)
     parser.add_argument('--batch_size', type=int, default=None)
@@ -20,26 +40,73 @@ def parse_args():
     parser.add_argument('--output_path', default=None)
     parser.add_argument('--limit', type=int, default=None)
     parser.add_argument('--no_cache', action="store_true")
+    parser.add_argument('--decontaminate', action="store_true")
+    parser.add_argument('--ngrams_path', default=None)
+    parser.add_argument('--ngrams_n_size', type=int, default=None)
     return parser.parse_args()
 
+
+def ensure_correct_decontamination_params(args):
+    valid = True
+    if args.decontaminate:
+        if not args.ngrams_n_size:
+            print("Please specify n size of training set n-grams. (--ngrams_n_size)")
+            valid = False
+        if not args.ngrams_path:
+            print("Please specify path containing training set n-grams. (--ngrams_path)")
+            valid = False
+    return valid
+
+
+# Returns a list containing all values of the source_list that
+# match at least one of the patterns
+def pattern_match(patterns, source_list):
+    task_names = set()
+    for pattern in patterns:
+        for matching in fnmatch.filter(source_list, pattern):
+            task_names.add(matching)
+    return list(task_names)
+
+
 def main():
     args = parse_args()
 
-    assert not args.provide_description # not implemented
+    if not ensure_correct_decontamination_params(args):
+        return
+
+    # assert not args.provide_description # not implemented
 
     if args.limit:
         print("WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT.")
 
-    if args.tasks == "all_tasks":
+    if args.task_type:
+        task_types = args.task_type.split(",")
+        task_names = list(dict(filter(lambda x: x[1].__bases__[0].__name__ in task_types,
+                                      tasks.TASK_REGISTRY.items())
+                               ).keys())
+
+    if args.tasks is None:
+        if args.task_type is None:
             task_names = tasks.ALL_TASKS
     else:
-        task_names = args.tasks.split(",")
+        task_names = pattern_match(args.tasks.split(","), tasks.ALL_TASKS)
 
-    results = evaluator.simple_evaluate(args.model, args.model_args, task_names, args.num_fewshot, args.batch_size, args.device, args.no_cache, args.limit)
-
-    dumped = json.dumps(results, indent=2)
+    if args.exclude_tasks:
+        exclude_tasks = pattern_match(args.exclude_tasks.split(","), task_names)
+        task_names = list(filter(lambda x: x not in exclude_tasks, task_names))
+        if len(task_names) == 0:
+            print("You must have excluded the tasks you specified, exiting.")
+            return
+
+    print(f"Selected Tasks: {task_names}")
+
+    results = evaluator.simple_evaluate(args.model, args.model_args, task_names,
+                                        num_fewshot=args.num_fewshot, batch_size=args.batch_size,
+                                        device=args.device, no_cache=args.no_cache, limit=args.limit,
+                                        decontaminate=args.decontaminate, ngrams_path=args.ngrams_path,
+                                        ngrams_n_size=args.ngrams_n_size)
+
+    dumped = json.dumps(results, indent=2)
     print(dumped)
 
     if args.output_path:
...
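
`MultiChoice` lets argparse validate comma-separated, wildcard-style task lists, and `pattern_match` expands them against the registry. A usage sketch (task names assumed for illustration; `pattern_match` builds a set, so result order is not guaranteed):

    choices = MultiChoice(["sciq", "logiqa", "pile_arxiv", "pile_github"])
    "sciq,pile_*" in choices  # True: every comma-separated pattern matches something
    "squad" in choices        # False: argparse would reject --tasks squad

    pattern_match(["pile_*"], ["sciq", "pile_arxiv", "pile_github"])
    # -> ["pile_arxiv", "pile_github"] (unordered)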
@@ -4,6 +4,8 @@ import json
 import jsonlines
 import io
 import datetime
+import mmap
+import tqdm
 
 def json_serial(obj):
     """JSON serializer for objects not serializable by default json code"""
@@ -79,12 +81,63 @@ class TextReader:
     def __init__(self, file_path):
         self.file_path = file_path
 
+    # Optimized mmap read with infrequent tqdm updates to maintain speed.
+    # Tested up to 250MB/s.
+    def read_tqdm(self, update_frequency=10000):
+        current_file_position = 0
+        line_counter = 0
+        with open(self.file_path, 'r') as fh, \
+             tqdm.tqdm(total=os.path.getsize(self.file_path), dynamic_ncols=True,
+                       unit="byte", unit_scale=1) as progress:
+            with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
+                for line in iter(mmap_obj.readline, b""):
+                    line = line.decode("utf-8")
+                    line_counter += 1
+                    if line_counter == update_frequency:
+                        new_file_pos = mmap_obj.tell()
+                        bytes_read = new_file_pos - current_file_position
+                        current_file_position = new_file_pos
+                        progress.update(bytes_read)
+                        line_counter = 0
+                    yield line[:-1]
+
+    def read_and_tell(self):
+        current_file_position = 0
+        with open(self.file_path, 'r', encoding="utf8") as fh:
+            with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
+                for line in iter(mmap_obj.readline, b""):
+                    line = line.decode("utf-8")
+                    new_file_pos = mmap_obj.tell()
+                    raw_bytes_read = new_file_pos - current_file_position
+                    current_file_position = new_file_pos
+                    yield line[:-1], raw_bytes_read
+
     def read(self):
         with open(self.file_path, 'r', encoding="utf8") as fh:
-            self.fh = fh
+            with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
+                for line in iter(mmap_obj.readline, b""):
+                    line = line.decode("utf-8")
+                    yield line[:-1]
+
+    def read_slow(self):
+        with open(self.file_path, 'r', encoding="utf8") as fh:
             while True:
-                line = self.fh.readline()
+                line = fh.readline()
                 if line == -1 or line == "":
                     break
-                else :
+                else:
                     yield line[:-1]
+
+
+# Optimized for speed. Decompresses the archive in shell before
+# using the mmap'd TextReader.
+class ZStdTextReader:
+    def __init__(self, file):
+        self.file = file
+
+    def read_tqdm(self):
+        decompressed_file = self.file[:-4]
+        print("Decompressing file, please wait...")
+        os.system(f"zstd -d {self.file}")  # linux decompress is faster
+        reader = TextReader(decompressed_file)
+        yield from reader.read_tqdm()
+        os.remove(decompressed_file)
\ No newline at end of file
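
A usage sketch for the new readers (file paths are placeholders): `read_tqdm` streams lines from a memory-mapped file with a byte-level progress bar, and `ZStdTextReader` shells out to the `zstd` CLI, which must be installed; the file name must end in `.zst`, since `file[:-4]` strips that extension before delegating to `TextReader`:

    # placeholder paths, for illustration
    for line in TextReader("train_shard.txt").read_tqdm():
        pass  # process each line

    for line in ZStdTextReader("train_shard.txt.zst").read_tqdm():
        pass  # decompressed to train_shard.txt, read, then deleted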