Commit 602d3e20 authored by thefazzer

Merge remote-tracking branch 'origin/master' into fazz/refactor-task-coqa

parents 222027b9 01630657
@@ -10,6 +10,49 @@ The goal of this project is to build a set of tools for evaluating LMs on typica...
2. Removing task val/test data from LM training set
3. Adding task training data to LM training set
### Overview of Tasks
| Task Name |Train|Val|Test| Metrics |
|---------------|-----|---|----|--------------------|
|cola |✓ |✓ |✓ |mcc |
|mnli |✓ |✓ |✓ |acc |
|mnli_mismatched|✓ |✓ |✓ |acc |
|mrpc |✓ |✓ |✓ |acc, f1 |
|rte |✓ |✓ |✓ |acc |
|qnli |✓ |✓ |✓ |acc |
|qqp |✓ |✓ |✓ |acc, f1 |
|sst |✓ |✓ |✓ |acc |
|wnli |✓ |✓ |✓ |acc |
|boolq |✓ |✓ |✓ |acc |
|cb |✓ |✓ |✓ |acc, f1 |
|copa |✓ |✓ |✓ |acc |
|multirc |✓ |✓ |✓ |acc |
|wic |✓ |✓ |✓ |acc |
|wsc |✓ |✓ |✓ |acc |
|lambada | |✓ | |perplexity, accuracy|
|piqa |✓ |✓ | |acc |
|arc_easy |✓ |✓ |✓ |acc |
|arc_challenge |✓ |✓ |✓ |acc |
|hellaswag |✓ |✓ |✓ |acc |
|race |✓ |✓ |✓ |acc |
|webqs |✓ | |✓ |acc |
|wsc273 | | |✓ |acc |
|winogrande |✓ |✓ |✓ |acc |
|anli_r1 |✓ |✓ |✓ |acc |
|anli_r2 |✓ |✓ |✓ |acc |
|anli_r3 |✓ |✓ |✓ |acc |
|arithmetic_2da | |✓ | |acc |
|arithmetic_2ds | |✓ | |acc |
|arithmetic_3da | |✓ | |acc |
|arithmetic_3ds | |✓ | |acc |
|arithmetic_4da | |✓ | |acc |
|arithmetic_4ds | |✓ | |acc |
|arithmetic_5da | |✓ | |acc |
|arithmetic_5ds | |✓ | |acc |
|arithmetic_2dm | |✓ | |acc |
|arithmetic_1dc | |✓ | |acc |
## Usage
### Evaluate a task
......
@@ -176,10 +176,42 @@ class Task(abc.ABC):
                [self.doc_to_text(doc) + self.doc_to_target(doc) for doc in self.fewshot_examples(k=num_fewshot)]
            ) + "\n\n"

-        example = self.doc_to_text(doc).strip()
+        example = self.doc_to_text(doc)
        return description + labeled_examples + example
class MultipleChoiceTask(Task):

    def doc_to_target(self, doc):
        return " " + doc['choices'][doc['gold']]

    def construct_requests(self, doc, ctx):
        lls = [
            rf.loglikelihood(ctx, " {}".format(choice))[0]
            for choice in doc['choices']
        ]
        return lls

    def process_results(self, doc, results):
        gold = doc["gold"]
        acc = 1. if np.argmax(results) == gold else 0.
        return {
            "acc": acc
        }

    def higher_is_better(self):
        return {
            "acc": True
        }

    def aggregation(self):
        return {
            "acc": mean
        }
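# Illustrative sketch (not part of this diff; names are hypothetical): a task can
# reuse all of the MultipleChoiceTask scoring logic by emitting docs that carry a
# 'choices' list of strings and a 'gold' index into that list.
#
# class ExampleMC(MultipleChoiceTask):
#     def doc_to_text(self, doc):
#         return "Question: " + doc["question"] + "\nAnswer:"
#
#     def validation_docs(self):
#         yield {"question": "2+2=?", "choices": ["3", "4"], "gold": 1}
#
#     # doc_to_target, construct_requests, process_results, aggregation and
#     # higher_is_better are all inherited from MultipleChoiceTask.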
def mean(arr):
    return sum(arr) / len(arr)
@@ -235,9 +267,66 @@ def perplexity(items):
    return math.exp(-mean(items))

req_ret_lens = {
-    'loglikelihood': 2
+    'loglikelihood': 2,
}
import os
import json
import hashlib
from sqlitedict import SqliteDict
def hash_args(args):
    dat = b""
    for arg in args:
        assert isinstance(arg, str) or isinstance(arg, int)
        dat += str(arg).encode()
        dat += b"\0"
    return hashlib.sha256(dat).hexdigest()
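# Note (added for clarity, not in the diff): CachingLM below builds its cache keys
# as the request method name plus this digest, e.g.
#     'loglikelihood_' + hash_args(("Question: 2+2=\nAnswer:", " 4"))
# so identical (context, continuation) pairs always map to the same SqliteDict entry.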
class CachingLM:
    def __init__(self, lm, cache_db):
        self.lm = lm
        self.cache_db = cache_db
        os.makedirs(os.path.dirname(cache_db), exist_ok=True)
        self.dbdict = SqliteDict(cache_db, autocommit=True)

    def __getattr__(self, attr):
        def fn(requests):
            res = []
            remaining_reqs = []

            # figure out which ones are cached and which ones are new
            for req in requests:
                hsh = attr + '_' + hash_args(req)
                if hsh in self.dbdict:
                    ob = self.dbdict[hsh]
                    assert ob is not None
                    res.append(ob)
                else:
                    res.append(None)
                    remaining_reqs.append(req)

            # actually run the LM
            rem_res = getattr(self.lm, attr)(remaining_reqs)

            # stick the new ones back into the list and also cache any of the new ones
            resptr = 0
            for req, r in zip(remaining_reqs, rem_res):
                while res[resptr] is not None:
                    resptr += 1
                res[resptr] = r

                # caching
                hsh = attr + '_' + hash_args(req)
                self.dbdict[hsh] = r

            return res
        return fn
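# Usage sketch (illustrative, not part of this diff; `some_lm` is a placeholder
# for any LM instance such as GPT2LM below):
#
#     lm = CachingLM(some_lm, "lm_cache/gpt2.db")
#     scores = lm.loglikelihood(requests)   # first call runs the underlying model
#     scores = lm.loglikelihood(requests)   # identical rerun is served from SQLite
#
# main.py wires this up behind the new --cache flag.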
class Request:
    def __init__(self, type, args, index=None):
......
@@ -12,6 +12,7 @@ class GPT2LM(LM):
        self.gpt2 = transformers.GPT2LMHeadModel.from_pretrained('gpt2').to(self.device)
        self.gpt2.eval()

        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
        self.tokenizer.pad_token = "<|endoftext|>"
    @classmethod
    def create_from_arg_string(cls, arg_string):
......
@@ -3,6 +3,7 @@ import transformers
from lm_eval.base import LM
from lm_eval import utils
from tqdm import tqdm
import time
def get_result(response, ctxlen):
@@ -21,6 +22,18 @@ def get_result(response, ctxlen):
    return continuation_logprobs, is_greedy
def oa_completion(**kwargs):
    import openai

    backoff_time = 3
    while True:
        try:
            return openai.Completion.create(**kwargs)
        except openai.error.OpenAIError:
            time.sleep(backoff_time)
            backoff_time *= 1.5
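# Note (added for clarity, not in the diff): oa_completion retries indefinitely on
# any OpenAIError, sleeping 3s, then 4.5s, 6.75s, ... (backoff grows by 1.5x per
# attempt), so transient rate-limit failures back off instead of aborting the run.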
class GPT3LM(LM):
    MAX_LENGTH = 2048
@@ -38,6 +51,9 @@ class GPT3LM(LM):
        import openai
        self.engine = engine
        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
        # to make the annoying "Using pad_token, but it is not set yet." error go away
        self.tokenizer.pad_token = "<|endoftext|>"
        self.truncate = truncate

        # Read from environment variable OPENAI_API_SECRET_KEY
@@ -50,11 +66,12 @@ class GPT3LM(LM):
    def loglikelihood(self, requests):
        import openai
-        for chunk in tqdm(utils.chunks(requests, self.REQ_CHUNK_SIZE)):
+        res = []
+        for chunk in tqdm(list(utils.chunks(requests, self.REQ_CHUNK_SIZE))):
            inps = []
            ctxlens = []
            for context, continuation in chunk:
-                print(context)
                context_enc = self.tokenizer.encode(context)
                continuation_enc = self.tokenizer.encode(continuation)
                inp = (context_enc + continuation_enc)[-self.MAX_LENGTH:]
@@ -63,7 +80,7 @@ class GPT3LM(LM):
                inps.append(inp)
                ctxlens.append(ctxlen)
-            response = openai.Completion.create(
+            response = oa_completion(
                engine=self.engine,
                prompt=inps,
                echo=True,
@@ -85,7 +102,7 @@ class GPT3LM(LM):
            inp = context_enc[-(self.MAX_LENGTH - self.MAX_GEN_TOKS):]
            ctxlen = len(context_enc) - max(0, len(context_enc) - (self.MAX_LENGTH - self.MAX_GEN_TOKS))

-            response = openai.Completion.create(
+            response = oa_completion(
                engine=self.engine,
                prompt=[inp],
                max_tokens=self.MAX_GEN_TOKS,
......
@@ -48,8 +48,8 @@ TASK_REGISTRY = {
    "piqa": piqa.PiQA,
    #"triviaqa": triviaqa.TriviaQA,
-    # "arc_easy": arc.ARCEasy, # not implemented yet
-    # "arc_challenge": arc.ARCChallenge, # not implemented yet
+    "arc_easy": arc.ARCEasy,
+    "arc_challenge": arc.ARCChallenge,
    # "quac": quac.QuAC, # not implemented yet
    "hellaswag": hellaswag.HellaSwag, # not implemented yet
    # "openbookqa": openbookqa.OpenBookQA, # not implemented yet
......
import numpy as np
from lm_eval.base import rf, mean
from . common import HFTask


class ARCEasy(HFTask):
    DATASET_PATH = "ai2_arc"
    DATASET_NAME = "ARC-Easy"

    letter_to_num = {'A': 0, 'B': 1, 'C': 2, 'D': 3, 'E': 4}

    def __init__(self):
        super().__init__()
        self.data = self.__clean_data()

    def __clean_data(self):
        """ Resolves various edge cases in the unprocessed HF ARC dataset. """
        # NOTE: Some `doc["answerKey"]`s are in numeric string format being one
        # of {'1', '2', '3', '4', '5'}. We map them back to letters.
        num_to_letter = {'1': 'A', '2': 'B', '3': 'C', '4': 'D', '5': 'E'}
        result = {}
        for split, data in self.data.items():
            result[split] = []
            for doc in data:
                # Ensure all `answerKey`s and `label`s are in letter format.
                doc["answerKey"] = num_to_letter.get(doc["answerKey"], doc["answerKey"])
                doc["choices"]["label"] = [
                    num_to_letter.get(label, label) for label in doc["choices"]["label"]
                ]
                result[split].append(doc)
        return result
    def has_training_docs(self):
        return True
@@ -21,7 +47,8 @@ class ARCEasy(HFTask):
        return "Question: " + doc['question'] + '\nAnswer:'

    def doc_to_target(self, doc):
-        return " " + doc['choices']['text'][doc['choices']['label'].index(doc['answerKey'])]
+        index = self.letter_to_num[doc["answerKey"]]
+        return " " + doc['choices']['text'][index]
    def construct_requests(self, doc, ctx):
        """ Uses RequestFactory to construct Requests and returns an iterable of
@@ -34,9 +61,11 @@ class ARCEasy(HFTask):
        language description, as well as the few shot examples, and the question
        part of the document for `doc`.
        """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        ll_choices = []
+        for choice in doc["choices"]["text"]:
+            ll_choices.append(rf.loglikelihood(ctx, " " + choice)[0])
+        return ll_choices
    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a
        dict where keys are the names of submetrics and values are the values of
@@ -47,8 +76,11 @@ class ARCEasy(HFTask):
        :param results:
            The results of the requests created in construct_requests.
        """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        gold = self.letter_to_num[doc["answerKey"]]
+        pred = np.argmax(results)
+        return {
+            "acc": pred == gold
+        }
    def aggregation(self):
        """
@@ -56,8 +88,9 @@ class ARCEasy(HFTask):
        A dictionary where keys are the names of submetrics and values are
        functions that aggregate a list of metrics
        """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        return {
+            "acc": mean
+        }
    def higher_is_better(self):
        """
@@ -65,8 +98,10 @@ class ARCEasy(HFTask):
        A dictionary where keys are the names of submetrics and values are
        whether a higher value of the submetric is better
        """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        return {
+            "acc": True
+        }
class ARCChallenge(ARCEasy):
    DATASET_PATH = "ai2_arc"
......
@@ -32,7 +32,7 @@ class Arithmetic(Task):
        self._docs = [self.load_doc(json.loads(line)) for line in jsons]

    def has_training_docs(self):
-        return True
+        return False

    def has_validation_docs(self):
        return True
@@ -41,10 +41,10 @@ class Arithmetic(Task):
        return False

    def training_docs(self):
-        return self._docs
+        return NotImplemented

    def validation_docs(self):
-        return self._docs[:100]
+        return self._docs

    def test_docs(self):
        return NotImplemented
......
@@ -61,7 +61,7 @@ class HellaSwag(HFTask):
            raise ValueError(
                "HellaSwag from HF datasets contained an invalid answer key")
        target = doc['endings'][index]
-        return self.remove_brackets(target)
+        return " " + self.remove_brackets(target)

    def construct_requests(self, doc, ctx):
        """ Uses RequestFactory to construct Requests and returns an iterable of
@@ -75,7 +75,7 @@ class HellaSwag(HFTask):
        """
        ll_answers = []
        for i in range(4):
-            continuation = self.remove_brackets(doc['endings'][i])
+            continuation = " " + self.remove_brackets(doc['endings'][i])
            ll_answers.append(rf.loglikelihood(ctx, continuation))
        return ll_answers
......
@@ -18,22 +18,22 @@ class LAMBADA(Task):
        return False

    def has_validation_docs(self):
-        return False
+        return True

    def has_test_docs(self):
-        return True
+        return False

    def training_docs(self):
        pass

    def validation_docs(self):
-        pass
-    def test_docs(self):
        with open("data/lambada/lambada_test.jsonl") as fh:
            for line in fh:
                yield json.loads(line)

+    def test_docs(self):
+        pass

    def doc_to_text(self, doc):
        return doc['text'].rsplit(' ', 1)[0]
......
@@ -46,12 +46,12 @@ class PiQA(Task):
        return ""

    def doc_to_text(self, doc):
-        return doc[0]['goal']
+        return doc[0]['goal'] + "\n"

    def doc_to_target(self, doc):
        #TODO: check if oa uses newline
        rightanswer = int(doc[1]) + 1
-        return '\n' + ''.join([doc[0]['goal'],' ',doc[0]['sol'+str(rightanswer)]])
+        return ''.join([doc[0]['goal'],' ',doc[0]['sol'+str(rightanswer)]])
    def construct_requests(self, doc, ctx):
        ll_1, _ = rf.loglikelihood(ctx, doc[0]['sol1'])
......
import json
import random
import os
-from lm_eval.base import Task, rf, mean
+from lm_eval.base import MultipleChoiceTask, rf, mean
from tqdm import auto as tqdm_lib
from . common import simple_accuracy_metric
import numpy as np
from ..utils import sh

-class SATAnalogies(Task):
+class SATAnalogies(MultipleChoiceTask):
    NEEDS_MANUAL_DL = True

    def __init__(self):
@@ -61,8 +61,8 @@ class SATAnalogies(Task):
            doc = {
                'source': source,
                'query': query.split(' ')[:2],
-                'choices': [c.split(' ')[:2] for c in choices],
-                'answer_key': ['a','b','c','d','e'].index(answer_key.strip()),
+                'choices': ["{} is to {}".format(*c.split(' ')[:2]) for c in choices],
+                'gold': ['a','b','c','d','e'].index(answer_key.strip()),
            }
            yield doc
@@ -72,35 +72,4 @@ class SATAnalogies(Task):
        return ""

    def doc_to_text(self, doc):
-        return "{} is to {} as ".format(*doc['query'])
+        return "{} is to {} as".format(*doc['query'])
-    def doc_to_target(self, doc):
-        return "{} is to {}".format(*doc['choices'][doc['answer_key']])
-
-    def construct_requests(self, doc, ctx):
-        lls = [
-            rf.loglikelihood(ctx, "{} is to {}".format(*doc['choices'][i]))[0]
-            for i in range(5)
-        ]
-        return lls
-
-    def process_results(self, doc, results):
-        gold = doc["answer_key"]
-        acc = 1. if np.argmax(results) == gold else 0.
-        return {
-            "acc": acc
-        }
-
-    def higher_is_better(self):
-        return {
-            "acc": True
-        }
-
-    def aggregation(self):
-        return {
-            "acc": mean
-        }
@@ -28,10 +28,10 @@ class BoolQ(HFTask):
        return "Read the following passages and answer each question with a yes or a no."

    def doc_to_text(self, doc):
-        return f"{doc['passage']}\nquestion: {doc['question']}\nanswer: "
+        return f"{doc['passage']}\nquestion: {doc['question']}\nanswer:"

    def doc_to_target(self, doc):
-        return yesno(doc['label'])
+        return " " + yesno(doc['label'])
    def construct_requests(self, doc, ctx):
@@ -156,12 +156,12 @@ class Copa(HFTask):
            "cause": "because",
            "effect": "therefore",
        }[doc["question"]]
-        return doc["premise"].strip()[:-1] + f" {connector} "
+        return doc["premise"].strip()[:-1] + f" {connector}"

    def doc_to_target(self, doc):
        correct_choice = doc["choice1"] if doc["label"] == 0 else doc["choice2"]
        # Connect the sentences
-        return self.convert_choice(correct_choice)
+        return " " + self.convert_choice(correct_choice)

    def construct_requests(self, doc, ctx):
        choice1 = " " + self.convert_choice(doc["choice1"])
@@ -435,11 +435,10 @@ class SGWinogradSchemaChallenge(HFTask):
    def doc_to_text(self, doc):
        raw_passage = doc["text"]
-        passage = (
-            raw_passage[:doc["span2_index"]]
-            + "*{}*".format(doc["span2_text"])
-            + raw_passage[doc["span2_index"] + len(doc["span2_text"]):]
-        )
+        # NOTE: HuggingFace span indices are word-based not character-based.
+        pre = " ".join(raw_passage.split()[:doc["span2_index"]])
+        post = raw_passage[len(pre) + len(doc["span2_text"]) + 1:]
+        passage = pre + " *{}*".format(doc['span2_text']) + post
        noun = doc["span1_text"]
        pronoun = doc["span2_text"]
        text = (
......
@@ -4,9 +4,11 @@ import numpy as np
import random
import itertools
import collections
import logging

-from lm_eval import models, tasks, evaluator
+from lm_eval import models, tasks, evaluator, base

logging.getLogger("openai").setLevel(logging.WARNING)

def parse_args():
    parser = argparse.ArgumentParser()
@@ -18,14 +20,19 @@ def parse_args():
    parser.add_argument('--seed', type=int, default=1234)
    parser.add_argument('--output_path', default=None)
    parser.add_argument('--limit', type=int, default=None)
    parser.add_argument('--cache', action="store_true")
    return parser.parse_args()

def main():
    args = parse_args()
    random.seed(args.seed)
    np.random.seed(args.seed)

    lm = models.get_model(args.model).create_from_arg_string(args.model_args)

    if args.cache:
        lm = base.CachingLM(lm, 'lm_cache/' + args.model + '_' + args.model_args.replace('=', '-').replace(',', '_') + '.db')

    if args.tasks == "all_tasks":
        task_names = tasks.ALL_TASKS
    else:
......
@@ -4,4 +4,5 @@ datasets>=1.2.1
click>=7.1
scikit-learn>=0.24.1
torch>=1.7
transformers>=4.1
+sqlitedict==1.6.0
\ No newline at end of file
from lm_eval import tasks
from pytablewriter import MarkdownTableWriter


writer = MarkdownTableWriter()
writer.headers = ["Task Name", "Train", "Val", "Test", "Metrics"]

values = []

def chk(tf):
    if tf:
        return '✓'
    else:
        return ' '

for tname, Task in tasks.TASK_REGISTRY.items():
    task = Task()
    values.append([tname, chk(task.has_training_docs()), chk(task.has_validation_docs()), chk(task.has_test_docs()), ', '.join(task.aggregation().keys())])

writer.value_matrix = values
print(writer.dumps())
\ No newline at end of file
@@ -37,8 +37,8 @@ def main():
    iters = []

    for set in args.sets.split(","):
-        if set == 'train' and task.has_train_docs():
-            docs = task.train_docs()
+        if set == 'train' and task.has_training_docs():
+            docs = task.training_docs()
        if set == 'val' and task.has_validation_docs():
            docs = task.validation_docs()
        if set == 'test' and task.has_test_docs():
......
import lm_eval.tasks as tasks
import lm_eval.models as models
import lm_eval.evaluator as evaluator
import random
import pytest
@@ -11,4 +12,21 @@ import pytest
def test_evaluator(taskname, Task):
    task_dict = tasks.get_task_dict([taskname])
    lm = models.get_model('dummy')()

    def ll_fn(reqs):
        for ctx, cont in reqs:
            # space convention: the continuation starts with a space unless the
            # context already ends with a newline
            assert ctx[-1] != ' '
            assert cont[0] == ' ' or ctx[-1] == '\n'

        res = []

        random.seed(42)
        for _ in reqs:
            res.append((-random.random(), False))

        return res

    lm.loglikelihood = ll_fn

    evaluator.evaluate(lm, task_dict, False, 0, 10)
\ No newline at end of file
import lm_eval.tasks as tasks
import lm_eval.base as base
-from unittest.mock import MagicMock
from itertools import islice
import pytest
@@ -43,6 +42,10 @@ def test_documents_and_requests(taskname, Task):
        assert isinstance(txt, str)
        assert isinstance(tgt, str)

        # space convention
        assert txt[-1] != ' '
        assert tgt[0] == ' ' or txt[-1] == '\n'

        reqs = task.construct_requests(doc, txt)
......