Commit f88bb827 authored by Jason Phang's avatar Jason Phang
Browse files

lib

parent cf80f340
import base
import nlp
def yesno(x):
    """Render a truthy value as 'yes' and a falsy value as 'no'."""
    return 'yes' if x else 'no'
def mean(x):
    """Arithmetic mean of a non-empty sized sequence of numbers.

    Raises ZeroDivisionError when the sequence is empty.
    """
    total = sum(x)
    count = len(x)
    return total / count
class BoolQ(base.Dataset):
    """BoolQ yes/no reading-comprehension task: passage + question -> yes/no.

    Docs are dicts with 'passage', 'question', and a boolean 'answer'
    (as provided by the `nlp` BoolQ loader).
    """

    def __init__(self):
        # Downloads/loads the BoolQ dataset (train + validation splits).
        self.dataset = nlp.load_dataset('boolq')

    def training_docs(self):
        yield from self.dataset['train']

    def validation_docs(self):
        yield from self.dataset['validation']

    def test_docs(self):
        # BoolQ ships no labeled test split.
        return []

    def fewshot_description(self):
        # BUGFIX: evaluate() calls fewshot_description(), but this class
        # originally defined only fewshot_prefix(), so this instruction text
        # was never actually used in prompts.
        return "Read the following passages and answer each question with a yes or a no."

    def fewshot_prefix(self):
        # Backward-compatible alias for the old (misnamed) method.
        return self.fewshot_description()

    def doc_to_text(self, doc, include_target=True):
        # Prompt format: passage, then the question, then (optionally) the
        # gold answer rendered as 'yes'/'no'.
        return f"{doc['passage']}\nquestion: {doc['question']}\nanswer: " + (yesno(doc['answer']) if include_target else "")

    def evaluate(self, docs, lm, provide_description, num_fewshot):
        """Return accuracy over `docs`.

        The model is scored as answering "yes" iff the loglikelihood of
        'yes' exceeds that of 'no' given the constructed context.
        BUGFIX: the context parts are now joined in one pass, so a
        num_fewshot of 0 no longer produces a spurious leading "\n\n".
        """
        acc = []
        for doc in docs:
            parts = []
            if provide_description:
                parts.append(self.fewshot_description())
            # fewshot_examples is presumably provided by base.Dataset and may
            # sample randomly, so it is intentionally re-drawn per doc.
            parts.extend(self.doc_to_text(d) for d in self.fewshot_examples(k=num_fewshot))
            parts.append(self.doc_to_text(doc, include_target=False).strip())
            ctx = '\n\n'.join(parts)
            ans = lm.loglikelihood(ctx, 'yes') > lm.loglikelihood(ctx, 'no')
            acc.append(int(ans == doc['answer']))
        return mean(acc)
\ No newline at end of file
import importlib
import os
from ..base import Registry
from lm_eval.base import Registry
MODEL_REGISTRY = Registry(registry_name="models")
# Load all modules in models directory to populate registry
......@@ -13,7 +13,7 @@ for file in os.listdir(models_dir):
and (file.endswith('.py') or os.path.isdir(path))
):
module_name = file[:file.find('.py')] if file.endswith('.py') else file
module = importlib.import_module('lm_evaluation_harness.models.' + module_name)
module = importlib.import_module('lm_eval.models.' + module_name)
def get_model(model_name):
......
import transformers
import torch
from ..base import LM
from lm_eval.base import LM
from . import MODEL_REGISTRY
......
import transformers
import torch
import torch.nn.functional as F
from ..base import LM
from .. import utils
from lm_eval.base import LM
from lm_eval import utils
from . import MODEL_REGISTRY
......
import os
import openai
import transformers
from ..base import LM
from .. import utils
from lm_eval.base import LM
from lm_eval import utils
from . import MODEL_REGISTRY
......@@ -15,7 +15,7 @@ class GPT3LM(LM):
openai.api_key = os.environ["OPENAI_API_SECRET_KEY"]
@classmethod
def create_from_args(cls, arg_string):
def create_from_arg_string(cls, arg_string):
args = utils.simple_parse_args_string(arg_string)
return cls(engine=args.get("engine", "davinci"))
......@@ -37,6 +37,7 @@ class GPT3LM(LM):
response = openai.Completion.create(
engine=self.engine,
prompt=full_text,
echo=True,
max_tokens=0, temperature=0.0,
logprobs=0,
)
......
import importlib
import os
from ..base import Registry
from lm_eval.base import Registry
TASK_REGISTRY = Registry(registry_name="tasks")
# Load all modules in the tasks directory to populate the task registry
......@@ -13,7 +13,7 @@ for file in os.listdir(tasks_dir):
and (file.endswith('.py') or os.path.isdir(path))
):
module_name = file[:file.find('.py')] if file.endswith('.py') else file
module = importlib.import_module('lm_evaluation_harness.tasks.' + module_name)
module = importlib.import_module('lm_eval.tasks.' + module_name)
ALL_TASKS = sorted(list(TASK_REGISTRY.registry))
......
import json
import random
from ..base import Dataset
from lm_eval.base import Dataset
from . import TASK_REGISTRY
......
import argparse
import json
import models
import tasks
from lm_eval import models, tasks
def parse_args():
parser = argparse.ArgumentParser()
......@@ -10,32 +10,34 @@ def parse_args():
parser.add_argument('--model_args', default="")
parser.add_argument('--tasks', default="all_tasks")
parser.add_argument('--provide_description', action="store_true")
parser.add_argument('--new_fewshot', action="store_true")
parser.add_argument('--num_fewshot', type=int, default=1)
return parser.parse_args()
def main():
    """Evaluate the selected model on the selected tasks and print JSON results.

    NOTE(review): the scraped diff left both the pre- and post-commit lines
    interleaved here; this is the reconstructed post-commit version
    (lm/task_dict names, .items() iteration, args.num_fewshot).
    """
    args = parse_args()
    # Build the language model from its registered factory and CLI arg string.
    lm = models.get_model(args.model).create_from_arg_string(args.model_args)
    if args.tasks == "all_tasks":
        task_names = tasks.ALL_TASKS
    else:
        task_names = args.tasks.split(",")
    task_dict = {
        task_name: tasks.get_task(task_name)()
        for task_name in task_names
    }
    results = {}
    for task_name, task in task_dict.items():
        # Only tasks with a validation split can be scored here.
        if not task.has_validation_docs():
            continue
        results[task_name] = task.evaluate(
            docs=task.validation_docs(),
            lm=lm,
            provide_description=args.provide_description,
            num_fewshot=args.num_fewshot,
        )
    print(json.dumps(results, indent=2))


if __name__ == "__main__":
    main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment