Commit 6597c347 authored by researcher2, committed by researcher2

Implement decontamination of evals against training set 13-grams.

parent 67e2bf8b
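Editor's note: for context, the commit compares word-level 13-grams from each evaluation document against 13-grams drawn from the model's training set. The sketch below is illustrative only; the real overlap computation lives in scripts/clean_training_data/contamination.py (the get_train_overlap helper imported further down) and is not part of this diff.

```python
# Illustrative sketch of 13-gram decontamination (not the repository's implementation).
def ngrams(text, n=13):
    """Yield word-level n-grams from a piece of text."""
    words = text.lower().split()
    for i in range(len(words) - n + 1):
        yield " ".join(words[i:i + n])

def is_contaminated(eval_text, training_ngrams, n=13):
    """True if any n-gram of the eval text also occurs in the training-set n-grams."""
    return any(gram in training_ngrams for gram in ngrams(eval_text, n))
```

In this sketch a document counts as contaminated if it shares at least one 13-gram with the training set; the exact criterion used by the contamination script is defined there, not in this diff. Contaminated documents are then excluded from the "_decontaminate" variant of each metric, as shown in the evaluator changes below.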
......@@ -128,6 +128,10 @@ class Task(abc.ABC):
"""Downloads the task dataset if necessary"""
pass
def should_decontaminate(self):
"""Whether this task supports decontamination against model training set."""
return False
@abc.abstractmethod
def has_training_docs(self):
"""Whether the task has a training set"""
......@@ -170,6 +174,10 @@ class Task(abc.ABC):
return rnd.sample(self._training_docs, k)
def doc_to_decontamination_query(self, doc):
print("Override doc_to_decontamination_query with document specific decontamination query.")
assert(False)
@abc.abstractmethod
def doc_to_text(self, doc):
pass
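Editor's note: by default a task opts out (should_decontaminate returns False), and doc_to_decontamination_query fails loudly if called without an override. A task that opts in is expected to override both, roughly as in this minimal sketch (FooTask and the "text" field are hypothetical; the LogiQA and SciQ hunks below show the real overrides):

```python
class FooTask(Task):  # hypothetical task, for illustration only
    def should_decontaminate(self):
        return True

    def doc_to_decontamination_query(self, doc):
        # Return the raw text to check for training-set overlap.
        # The "text" field is an assumed document layout, not from this commit.
        return doc["text"]
```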
......@@ -292,6 +300,10 @@ class MultipleChoiceTask(Task):
class PerplexityTask(Task, abc.ABC):
def should_decontaminate(self):
"""Whether this task supports decontamination against model training set."""
return True
def has_training_docs(self):
return False
......@@ -314,6 +326,9 @@ class PerplexityTask(Task, abc.ABC):
"bits_per_byte": False,
}
def doc_to_decontamination_query(self, doc):
return doc
def doc_to_text(self, doc):
return ""
......
......@@ -5,13 +5,16 @@ import lm_eval.metrics
import lm_eval.models
import lm_eval.tasks
import lm_eval.base
from scripts.clean_training_data.contamination import get_train_overlap
import numpy as np
def simple_evaluate(model, model_args, task_names, num_fewshot=0, batch_size=None, device=None, no_cache=False, limit=None, bootstrap_iters=100000):
def simple_evaluate(model, model_args, task_names, num_fewshot=0, batch_size=None, device=None,
no_cache=False, limit=None, bootstrap_iters=100000, decontaminate=False,
ngrams_path=None, ngrams_n_size=None):
random.seed(1234)
np.random.seed(1234)
lm = lm_eval.models.get_model(model).create_from_arg_string(model_args, {
lm = lm_eval.models.MODEL_REGISTRY[model].create_from_arg_string(model_args, {
'batch_size': batch_size, 'device': device
})
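Editor's note: a hedged example of driving the new simple_evaluate interface directly. The model name, task names, and n-gram path below are placeholders, not values taken from this commit:

```python
from lm_eval import evaluator

# Placeholders: model, tasks, and n-gram location are examples only.
results = evaluator.simple_evaluate(
    model="gpt2",
    model_args="pretrained=gpt2",
    task_names=["logiqa", "sciq"],
    num_fewshot=0,
    decontaminate=True,
    ngrams_path="/path/to/training_set_13grams/",  # precomputed training-set n-grams
    ngrams_n_size=13,
)
```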
......@@ -19,7 +22,8 @@ def simple_evaluate(model, model_args, task_names, num_fewshot=0, batch_size=Non
lm = lm_eval.base.CachingLM(lm, 'lm_cache/' + model + '_' + model_args.replace('=', '-').replace(',', '_').replace('/', '-') + '.db')
task_dict = lm_eval.tasks.get_task_dict(task_names)
results = evaluate(lm, task_dict, False, num_fewshot, limit)
results = evaluate(lm, task_dict, False, num_fewshot, limit, bootstrap_iters=bootstrap_iters,
decontaminate=decontaminate, ngrams_path=ngrams_path, ngrams_n_size=ngrams_n_size)
# add info about the model and few shot config
results["config"] = {
......@@ -35,10 +39,15 @@ def simple_evaluate(model, model_args, task_names, num_fewshot=0, batch_size=Non
return results
decontaminate_suffix = "_decontaminate"
def evaluate(lm, task_dict, provide_description, num_fewshot, limit, bootstrap_iters=100000):
def evaluate(lm, task_dict, provide_description, num_fewshot, limit, bootstrap_iters=100000,
decontaminate=False, ngrams_path=None, ngrams_n_size=None):
assert not provide_description # not implemented. todo: implement proper description-providing system
if decontaminate:
assert ngrams_path and ngrams_n_size
# TODO: completely refactor this entire function to not be a huge mess, ideally breaking it down into smaller pieces
task_dict_items = [(name, task) for name, task in task_dict.items() if(task.has_validation_docs() or task.has_test_docs())]
......@@ -49,6 +58,8 @@ def evaluate(lm, task_dict, provide_description, num_fewshot, limit, bootstrap_i
requests = collections.defaultdict(list)
requests_origin = collections.defaultdict(list)
overlaps = collections.defaultdict(list) # {task_name: contaminated_docs}
# If we ever run into issues where the eval tasks don't fit in memory and we can't afford a machine with bigger memory,
# we can always modify this plumbing to support that, but I didn't want to include it just yet because overengineering is bad
# (or we could make it write the requests to disk and then read them back out again - probably using an sqlite db because of all the moving parts we have)
......@@ -57,6 +68,8 @@ def evaluate(lm, task_dict, provide_description, num_fewshot, limit, bootstrap_i
docs = {}
docs_for_decontamination = collections.defaultdict(list)
# get lists of each type of request
for task_name, task in task_dict_items:
versions[task_name] = task.VERSION
......@@ -64,7 +77,9 @@ def evaluate(lm, task_dict, provide_description, num_fewshot, limit, bootstrap_i
# TODO: the test-fallback-to-val system isn't final, we should revisit it at some point
if task.has_test_docs():
task_doc_func = task.test_docs
task_set = "test" # Required for caching in the decontamination
elif task.has_validation_docs():
task_set = "val" # Required for caching in the decontamination
task_doc_func = task.validation_docs
# deterministically shuffle docs and chop off the first `limit` because sometimes docs are in some kind of order
......@@ -74,6 +89,10 @@ def evaluate(lm, task_dict, provide_description, num_fewshot, limit, bootstrap_i
rnd.shuffle(task_docs)
for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):
if decontaminate and task.should_decontaminate():
docs_for_decontamination[(task_name, task_set)].append(task.doc_to_decontamination_query(doc))
docs[(task_name, doc_id)] = doc
ctx = task.fewshot_context(
......@@ -84,13 +103,18 @@ def evaluate(lm, task_dict, provide_description, num_fewshot, limit, bootstrap_i
)
reqs = task.construct_requests(doc, ctx)
if not isinstance(reqs, (list, tuple)): reqs = [reqs]
for i, req in enumerate(reqs):
requests[req.type].append(req)
# i: index in requests for a single task instance
# doc_id: unique id that we can get back to a doc using `docs`
requests_origin[req.type].append((i, task_name, doc, doc_id))
# Compare all tasks/sets at once to ensure a single training set scan
if decontaminate:
print("Finding train/test overlap, please wait...")
overlaps = get_train_overlap(docs_for_decontamination, ngrams_path, ngrams_n_size, limit)
# all responses for each (task, doc)
process_res_queue = collections.defaultdict(list)
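Editor's note: get_train_overlap comes from scripts/clean_training_data/contamination.py, which is not part of this diff. From how the result is consumed below, it is expected to map each task name to the collection of contaminated doc ids. A toy stand-in with that contract (not the actual implementation; it takes an in-memory n-gram set instead of the ngrams_path directory used by the real helper) might look like:

```python
# Toy stand-in for get_train_overlap, matching how `overlaps` is used below:
# it returns {task_name: set_of_contaminated_doc_ids}.
def toy_get_train_overlap(docs_for_decontamination, training_ngrams, n=13):
    overlaps = {}
    for (task_name, task_set), queries in docs_for_decontamination.items():
        contaminated = set()
        for doc_id, query in enumerate(queries):
            words = query.lower().split()
            grams = {" ".join(words[i:i + n]) for i in range(len(words) - n + 1)}
            if grams & training_ngrams:
                contaminated.add(doc_id)
        overlaps[task_name] = contaminated
    return overlaps
```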
......@@ -121,15 +145,23 @@ def evaluate(lm, task_dict, provide_description, num_fewshot, limit, bootstrap_i
metrics = task.process_results(doc, requests)
for metric, value in metrics.items():
vals[(task_name, metric)].append(value)
# Re-use the evaluation for the decontaminated set by just ignoring the overlaps
if decontaminate and task_name in overlaps:
if doc_id not in overlaps[task_name]:
vals[(task_name, metric + decontaminate_suffix)].append(value)
# aggregate results
for (task_name, metric), items in vals.items():
task = task_dict[task_name]
results[task_name][metric] = task.aggregation()[metric](items)
real_metric = metric # key when looking up the metric with task.aggregation
if metric.endswith(decontaminate_suffix):
real_metric = metric.replace(decontaminate_suffix, "") # decontaminated still uses the same metric
results[task_name][metric] = task.aggregation()[real_metric](items)
# hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
# so we run them with fewer iterations. Still looking for a cleaner way to do this
stderr = lm_eval.metrics.stderr_for_metric(task.aggregation()[metric], bootstrap_iters=min(bootstrap_iters, 1000) if metric in ["bleu", "chrf", "ter"] else bootstrap_iters)
stderr = lm_eval.metrics.stderr_for_metric(task.aggregation()[real_metric], bootstrap_iters=min(bootstrap_iters, 1000) if metric in ["bleu", "chrf", "ter"] else bootstrap_iters)
if stderr is not None:
results[task_name][metric + "_stderr"] = stderr(items)
......
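Editor's note: the net effect on the output is that each decontaminated task reports every metric twice, once over all docs and once restricted to docs with no training-set overlap, keyed with the "_decontaminate" suffix (and matching "_stderr" entries). Roughly, with invented numbers:

```python
# Illustrative shape of the results dict for one task (values are made up).
results = {
    "logiqa": {
        "acc": 0.2304,
        "acc_stderr": 0.0165,
        "acc_decontaminate": 0.2297,          # same metric, overlapping docs excluded
        "acc_decontaminate_stderr": 0.0168,
    }
}
```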
......@@ -18,9 +18,14 @@ class GPT2LM(LM):
assert isinstance(pretrained, str)
assert isinstance(batch_size, int)
if device:
if device not in ["cuda", "cpu"]:
device = int(device)
self.device = torch.device(device)
print(f"Using device '{device}'")
else:
print("Device not specificed")
print(f"Cuda Available? {torch.cuda.is_available()}")
self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# TODO: update this to be less of a hack once subfolder is fixed in HF
......
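Editor's note: the device handling now accepts "cuda", "cpu", or a bare GPU index, and falls back to CUDA-if-available when no device is given. Pulled out of the constructor, the resolution logic is roughly this standalone sketch:

```python
import torch

def resolve_device(device=None):
    # Mirrors the GPT2LM constructor logic above; standalone sketch for clarity.
    if device:
        if device not in ["cuda", "cpu"]:
            device = int(device)  # e.g. "--device 0" selects GPU 0
        return torch.device(device)
    print(f"Cuda Available? {torch.cuda.is_available()}")
    return torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
```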
......@@ -50,6 +50,7 @@ class LogiQA(MultipleChoiceTask):
return prompt
choices = ['a', 'b', 'c', 'd']
return {
"passage": doc["passage"], # Used for decontamination
"query": format_example(doc, choices),
"choices": doc["options"],
"gold": choices.index(doc["answerKey"])
......@@ -86,3 +87,9 @@ class LogiQA(MultipleChoiceTask):
def doc_to_text(self, doc):
return doc["query"]
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["passage"]
......@@ -63,4 +63,10 @@ class SciQ(MultipleChoiceTask):
return self.load_docs("data/sciq/SciQ dataset-2 3/test.json")
def doc_to_text(self, doc):
return "{}\nQuestion: {}\nAnswer:".format(doc["source"], doc["query"]).strip()
\ No newline at end of file
return "{}\nQuestion: {}\nAnswer:".format(doc["source"], doc["query"]).strip()
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["source"] + " " + doc["query"]
\ No newline at end of file
import argparse
import json
import numpy as np
import random
import logging
import fnmatch
from lm_eval import models, tasks, evaluator, base
from lm_eval import tasks, evaluator
logging.getLogger("openai").setLevel(logging.WARNING)
class MultiChoice:
def __init__(self, choices):
self.choices = choices
# Simple wildcard support (linux filename patterns)
def __contains__(self, values):
for value in values.split(","):
if len(fnmatch.filter(self.choices, value)) == 0:
return False
return True
def __iter__(self):
for choice in self.choices:
yield choice
# Get task base classes for filtering
task_types = list(set([task.__bases__[0].__name__ for task in tasks.TASK_REGISTRY.values()]))
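Editor's note: task_types collects the immediate base class of every registered task, so --task_type filters on names like "MultipleChoiceTask" or "PerplexityTask" (both defined in lm_eval/base.py). A small sketch of the same idea, mirroring the --task_type handling in main() further down:

```python
# Select every registered multiple-choice task by its base-class name.
mc_tasks = [name for name, task in tasks.TASK_REGISTRY.items()
            if task.__bases__[0].__name__ == "MultipleChoiceTask"]
```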
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--model', required=True)
parser.add_argument('--model_args', default="")
parser.add_argument('--tasks', default="all_tasks")
parser.add_argument('--tasks', default=None, choices=MultiChoice(tasks.ALL_TASKS))
parser.add_argument('--task_type', default=None, choices=MultiChoice(task_types))
parser.add_argument('--exclude_tasks', default=None, choices=MultiChoice(tasks.ALL_TASKS))
parser.add_argument('--provide_description', action="store_true")
parser.add_argument('--num_fewshot', type=int, default=0)
parser.add_argument('--batch_size', type=int, default=None)
......@@ -20,26 +40,73 @@ def parse_args():
parser.add_argument('--output_path', default=None)
parser.add_argument('--limit', type=int, default=None)
parser.add_argument('--no_cache', action="store_true")
parser.add_argument('--decontaminate', action="store_true")
parser.add_argument('--ngrams_path', default=None)
parser.add_argument('--ngrams_n_size', type=int, default=None)
return parser.parse_args()
def ensure_correct_decontamination_params(args):
valid = True
if args.decontaminate:
if not args.ngrams_n_size:
print("Please specify n size of training set n-grams. (--ngrams_n_size)")
valid = False
if not args.ngrams_path:
print("Please specify path containing training set n-grams. (--ngrams_path)")
valid = False
return valid
# Returns a list containing all values of the source_list that
# match at least one of the patterns
def pattern_match(patterns, source_list):
task_names = set()
for pattern in patterns:
for matching in fnmatch.filter(source_list, pattern):
task_names.add(matching)
return list(task_names)
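Editor's note: pattern_match gives --tasks and --exclude_tasks simple shell-style wildcard support. A toy example with made-up task names:

```python
# Toy example; the task names here are hypothetical.
available = ["logiqa", "sciq", "lambada", "lambada_cloze"]
print(pattern_match(["lambada*", "sciq"], available))
# -> ['lambada', 'lambada_cloze', 'sciq'] (order not guaranteed, since a set is used)
```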
def main():
args = parse_args()
if not ensure_correct_decontamination_params(args):
return
assert not args.provide_description # not implemented
# assert not args.provide_description # not implemented
if args.limit:
print("WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT.")
if args.tasks == "all_tasks":
task_names = tasks.ALL_TASKS
if args.task_type:
task_types = args.task_type.split(",")
task_names = list(dict(filter(lambda x: x[1].__bases__[0].__name__ in task_types,
tasks.TASK_REGISTRY.items())
).keys())
if args.tasks is None:
if args.task_type is None:
task_names = tasks.ALL_TASKS
else:
task_names = args.tasks.split(",")
task_names = pattern_match(args.tasks.split(","), tasks.ALL_TASKS)
results = evaluator.simple_evaluate(args.model, args.model_args, task_names, args.num_fewshot, args.batch_size, args.device, args.no_cache, args.limit)
if args.exclude_tasks:
exclude_tasks = pattern_match(args.exclude_tasks.split(","), task_names)
task_names = list(filter(lambda x: x not in exclude_tasks, task_names))
dumped = json.dumps(results, indent=2)
if len(task_names) == 0:
print("You must have excluded the tasks you specified, exiting.")
return
print(f"Selected Tasks: {task_names}")
results = evaluator.simple_evaluate(args.model, args.model_args, task_names,
num_fewshot=args.num_fewshot, batch_size=args.batch_size,
device=args.device, no_cache=args.no_cache, limit=args.limit,
decontaminate=args.decontaminate, ngrams_path=args.ngrams_path,
ngrams_n_size=args.ngrams_n_size)
dumped = json.dumps(results, indent=2)
print(dumped)
if args.output_path:
......
......@@ -4,6 +4,8 @@ import json
import jsonlines
import io
import datetime
import mmap
import tqdm
def json_serial(obj):
"""JSON serializer for objects not serializable by default json code"""
......@@ -79,12 +81,63 @@ class TextReader:
def __init__(self, file_path):
self.file_path = file_path
# Optimized mmap read with infrequent tqdm updates to maintain speed
# Tested up to 250MB/s.
def read_tqdm(self, update_frequency=10000):
current_file_position = 0
line_counter = 0
with open(self.file_path, 'r') as fh, \
tqdm.tqdm(total=os.path.getsize(self.file_path), dynamic_ncols=True,
unit="byte", unit_scale=1) as progress:
with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
for line in iter(mmap_obj.readline, b""):
line = line.decode("utf-8")
line_counter += 1
if line_counter == update_frequency:
new_file_pos = mmap_obj.tell()
bytes_read = new_file_pos - current_file_position
current_file_position = new_file_pos
progress.update(bytes_read)
line_counter = 0
yield line[:-1]
def read_and_tell(self):
current_file_position = 0
with open(self.file_path, 'r', encoding="utf8") as fh:
with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
for line in iter(mmap_obj.readline, b""):
line = line.decode("utf-8")
new_file_pos = mmap_obj.tell()
raw_bytes_read = new_file_pos - current_file_position
current_file_position = new_file_pos
yield line[:-1], raw_bytes_read
def read(self):
with open(self.file_path, 'r', encoding="utf8") as fh:
self.fh = fh
with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
for line in iter(mmap_obj.readline, b""):
line = line.decode("utf-8")
yield line[:-1]
def read_slow(self):
with open(self.file_path, 'r', encoding="utf8") as fh:
while True:
line = self.fh.readline()
line = fh.readline()
if line == -1 or line == "":
break
else :
yield line[:-1]
\ No newline at end of file
else:
yield line[:-1]
# Optimized for speed. Decompresses the archive in shell before
# using the mmap'd TextReader.
class ZStdTextReader:
def __init__(self, file):
self.file = file
def read_tqdm(self):
decompressed_file = self.file[:-4]
print("Decompressing file, please wait...")
os.system(f"zstd -d {self.file}") # linux decompress is faster
reader = TextReader(decompressed_file)
yield from reader.read_tqdm()
os.remove(decompressed_file)
\ No newline at end of file
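Editor's note: a hedged usage sketch for the new readers. The file paths and the process() handler are placeholders; ZStdTextReader shells out to the zstd CLI, so it assumes zstd is installed and the archive name ends in ".zst":

```python
# Placeholders: paths and the per-line handler are examples only.
reader = TextReader("/data/pile/eval_docs.txt")
for line in reader.read_tqdm():        # mmap-backed read with a byte-level progress bar
    process(line)                      # hypothetical per-line handler

zreader = ZStdTextReader("/data/pile/00.jsonl.zst")
for line in zreader.read_tqdm():       # decompresses via the zstd CLI, then mmap-reads
    process(line)
```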