Commit e41a082c authored by Leo Gao

Update

parent 76e65788
@@ -26,7 +26,7 @@ class LM(abc.ABC):
         pass
 
     @abc.abstractmethod
-    def gen_greedy(self, requests):
+    def greedy_until(self, requests):
         """Generate greedily until a stopping sequence
 
         :param requests: list
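For orientation, here is a usage sketch of the renamed interface. The per-request shape is an assumption, since the visible docstring only says `requests` is a list; a (context, stopping-sequence) pair per request is one natural reading, and `lm` stands in for any concrete LM implementation.

    # Usage sketch only; the request shape is an assumption, not from the diff.
    requests = [
        ("Q: What is the capital of France?\nA:", "\n"),  # generate until newline
        ("Q: 2 + 2 =\nA:", "\n"),
    ]
    completions = lm.greedy_until(requests)  # one string per request, in order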
@@ -104,7 +104,11 @@ class Dataset(abc.ABC):
         return random.sample(self._traindocs, k)
 
     @abc.abstractmethod
-    def doc_to_text(self, doc, include_target=True):
+    def doc_to_text(self, doc):
+        pass
+
+    @abc.abstractmethod
+    def doc_to_target(self, doc):
         pass
 
     @abc.abstractmethod
@@ -113,13 +117,14 @@ class Dataset(abc.ABC):
 
     @abc.abstractmethod
     def process_results(self, doc, results):
-        """Take a single document and the LM results and evaluates, returning a dict with the following format:
+        """Take a single document and the LM results and evaluate, returning a
+        list of dicts, each with the following format:
 
         {
             "submetric": str,
             "value": float,
             "higher_is_better": bool,
-            "aggregation": (list -> float),
+            "aggregation": ([float] -> float),
         }
 
         * `submetric` should be the name of the metric
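Concretely, a conforming return value for one document looks like the following sketch, which mirrors the BoolQ implementation further down in this commit; `mean` is the aggregation helper imported from lm_eval.base.

    # Sketch of one conforming process_results return value (mirrors BoolQ below).
    [
        {
            "submetric": "acc",       # metric name
            "value": 1.0,             # this document's score
            "higher_is_better": True,
            "aggregation": mean,      # [float] -> float reducer over all docs
        }
    ]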
@@ -138,10 +143,12 @@ class Dataset(abc.ABC):
     def fewshot_context(self, doc, num_fewshot, provide_description):
         raw_description = self.fewshot_description()
         description = (raw_description + "\n===\n\n") if provide_description and raw_description else ""
 
         labeled_examples = "\n\n".join(
-            map(self.doc_to_text, self.fewshot_examples(k=num_fewshot))
+            [self.doc_to_text(doc) + self.doc_to_target(doc) for doc in self.fewshot_examples(k=num_fewshot)]
         ) + "\n\n"
 
-        example = self.doc_to_text(doc, include_target=False).strip()
+        example = self.doc_to_text(doc).strip()
         return description + labeled_examples + example
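Assembled, the returned prompt has the shape below (illustrative; `===` is the literal separator from the code above, and each labeled example is `doc_to_text(d) + doc_to_target(d)`):

    <description>
    ===

    <text of few-shot doc 1><target of few-shot doc 1>

    <text of few-shot doc 2><target of few-shot doc 2>

    <text of the evaluated doc, stripped>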
@@ -153,12 +160,12 @@ def median(arr):
     return arr[len(arr) // 2]
 
 
-Request = collections.namedtuple('Request', ('type', 'args', 'kwargs'))
+Request = collections.namedtuple('Request', ('type', 'args'))
 
 
 class RequestFactory:
     def __getattr__(self, attr):
-        def fn(*args, **kwargs):
-            return Request(attr, args, kwargs)
+        def fn(*args):
+            return Request(attr, args)
         return fn
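Usage sketch: the `rf` imported in the task file below (`from lm_eval.base import rf`) is evidently a module-level instance of this factory, so any attribute access yields a deferred request constructor rather than calling the LM directly.

    rf = RequestFactory()
    req = rf.loglikelihood("some context", " yes")
    # -> Request(type='loglikelihood', args=('some context', ' yes'))
    # The runner later dispatches batches via getattr(lm, req.type).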
@@ -3,6 +3,7 @@ import torch
 import torch.nn.functional as F
 
 from lm_eval.base import LM
 from lm_eval import utils
+from tqdm import tqdm
 
 class GPT2LM(LM):
@@ -20,7 +21,7 @@ class GPT2LM(LM):
     def loglikelihood(self, requests):
         res = []
         # TODO: vectorize properly
-        for context, continuation in requests:
+        for context, continuation in tqdm(requests):
             # when too long to fit in context, truncate from the left
             context_enc = self.tokenizer.encode(context)
             continuation_enc = self.tokenizer.encode(continuation)
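The truncation comment implies something roughly like the sketch below. Both `self.max_length` and the exact slice are assumptions here, since the real slicing code falls outside this hunk.

    # Hypothetical sketch of the left-truncation the comment describes:
    # keep only the newest tokens so the continuation is never cut off.
    # `self.max_length` is an assumed attribute, not shown in this diff.
    input_enc = (context_enc + continuation_enc)[-self.max_length:]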
@@ -35,6 +36,6 @@ class GPT2LM(LM):
 
         return res
 
-    def gen_greedy(self, requests):
+    def greedy_until(self, requests):
         # TODO: implement
         pass
\ No newline at end of file
 import numpy as np
 from tqdm import auto as tqdm_lib
 
 from . common import HFTask, simple_accuracy_metric, yesno
+from lm_eval.base import rf, mean
 
 class BoolQ(HFTask):
     DATASET_PATH = "super_glue"
@@ -19,21 +19,33 @@ class BoolQ(HFTask):
     def fewshot_description(self):
         return "Read the following passages and answer each question with a yes or a no."
 
-    def doc_to_text(self, doc, include_target=True):
-        return f"{doc['passage']}\nquestion: {doc['question']}\nanswer: " \
-            + (yesno(doc['label']) if include_target else "")
+    def doc_to_text(self, doc):
+        return f"{doc['passage']}\nquestion: {doc['question']}\nanswer: "
+
+    def doc_to_target(self, doc):
+        return yesno(doc['label'])
 
-    def evaluate(self, docs, lm, provide_description, num_fewshot):
-        golds = [doc["label"] for doc in docs]
-
-        preds = []
-        for doc in docs:
-            ctx = self.fewshot_context(
-                doc=doc,
-                provide_description=provide_description,
-                num_fewshot=num_fewshot,
-            )
-
-            preds.append(lm.loglikelihood(ctx, ' yes') > lm.loglikelihood(ctx, ' no'))
-
-        return simple_accuracy_metric(preds=preds, golds=golds)
+    def construct_requests(self, ctx):
+        ll_yes = rf.loglikelihood(ctx, ' yes')
+        ll_no = rf.loglikelihood(ctx, ' no')
+
+        return ll_yes, ll_no
+
+    def process_results(self, doc, results):
+        ll_yes, ll_no = results
+        gold = doc["label"]
+
+        acc = 1. if (ll_yes > ll_no) == gold else 0.
+
+        return [
+            {
+                "submetric": "acc",
+                "value": acc,
+                "higher_is_better": True,
+                "aggregation": mean
+            }
+        ]
 
 
 class CommitmentBank(HFTask):
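For a single document, the two new BoolQ methods above compose with any LM like so; this is a sketch, with `task`, `doc`, and `lm` standing in for real objects.

    # Sketch of the request/result round trip for one BoolQ document.
    ctx = task.fewshot_context(doc=doc, num_fewshot=0, provide_description=False)
    ll_yes, ll_no = task.construct_requests(ctx)        # deferred Request objects
    res = lm.loglikelihood([ll_yes.args, ll_no.args])   # batched by the runner
    metrics = task.process_results(doc, res)            # [{"submetric": "acc", ...}]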
@@ -3,6 +3,7 @@ import json
 import numpy as np
 import random
 import itertools
+import collections
 
 from lm_eval import models, tasks
@@ -30,17 +31,43 @@ def main():
     else:
         task_names = args.tasks.split(",")
 
     task_dict = tasks.get_task_dict(task_names)
+    task_dict_items = list(task_dict.items())
 
     results = {}
-    for task_name, task in task_dict.items():
+
+    requests = collections.defaultdict(list)
+    requests_lengths = collections.defaultdict(list)
+
+    for task_name, task in task_dict_items:
         # TODO: fall back to test docs
         if not task.has_validation_docs():
             continue
 
-        result = task.evaluate(
-            docs=itertools.isslice(task.validation_docs(), 0, args.limit),
-            lm=lm,
-            provide_description=args.provide_description,
-            num_fewshot=args.num_fewshot,
-        )
-        results[task_name] = result
+        for doc in itertools.islice(task.validation_docs(), 0, args.limit):
+            ctx = task.fewshot_context(
+                doc=doc,
+                provide_description=args.provide_description,
+                num_fewshot=args.num_fewshot,
+            )
+
+            reqs = task.construct_requests(ctx)
+            lengths = collections.defaultdict(int)
+            for req in reqs:
+                requests[req.type].append(req)
+                lengths[req.type] += 1
+
+            for type, ct in lengths.items():
+                requests_lengths[type].append(ct)
+
+    # TODO: finish implementation
+    for reqname, reqs in requests.items():
+        lm_res = getattr(lm, reqname)([req.args for req in reqs])
+
+    for task_name, task in task_dict_items:
+        if not task.has_validation_docs():
+            continue
 
     dumped = json.dumps(results, indent=2)
     print(dumped)
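One way the "# TODO: finish implementation" above might be completed, sketched under assumptions: this is not code from the commit, it simply replays the same doc order, slices each document's answers back out of the flat per-type result lists with a cursor, scores them with process_results, and reduces each submetric with its declared aggregation. It assumes few-shot contexts are reproducible across the two passes (e.g. num_fewshot=0).

    # Hypothetical completion of the TODO (an assumption, not commit code).
    lm_results = {
        reqname: getattr(lm, reqname)([req.args for req in reqs])
        for reqname, reqs in requests.items()
    }
    cursors = collections.defaultdict(int)  # read position per request type

    for task_name, task in task_dict_items:
        if not task.has_validation_docs():
            continue
        values = collections.defaultdict(list)   # submetric -> per-doc scores
        aggregators = {}
        for doc in itertools.islice(task.validation_docs(), 0, args.limit):
            ctx = task.fewshot_context(
                doc=doc,
                provide_description=args.provide_description,
                num_fewshot=args.num_fewshot,
            )
            reqs = task.construct_requests(ctx)
            # pull this doc's answers back out of the flat per-type result lists
            doc_res = []
            for req in reqs:
                doc_res.append(lm_results[req.type][cursors[req.type]])
                cursors[req.type] += 1
            for metric in task.process_results(doc, doc_res):
                values[metric["submetric"]].append(metric["value"])
                aggregators[metric["submetric"]] = metric["aggregation"]
        results[task_name] = {
            name: aggregators[name](vals) for name, vals in values.items()
        }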