Commit e41a082c authored by Leo Gao

Update

parent 76e65788
@@ -26,7 +26,7 @@ class LM(abc.ABC):
         pass
 
     @abc.abstractmethod
-    def gen_greedy(self, requests):
+    def greedy_until(self, requests):
        """Generate greedily until a stopping sequence
 
        :param requests: list
@@ -104,7 +104,11 @@ class Dataset(abc.ABC):
         return random.sample(self._traindocs, k)
 
     @abc.abstractmethod
-    def doc_to_text(self, doc, include_target=True):
+    def doc_to_text(self, doc):
+        pass
+
+    @abc.abstractmethod
+    def doc_to_target(self, doc):
         pass
 
     @abc.abstractmethod
@@ -113,13 +117,14 @@ class Dataset(abc.ABC):
     @abc.abstractmethod
     def process_results(self, doc, results):
-        """Take a single document and the LM results and evaluates, returning a dict with the following format:
+        """Take a single document and the LM results, evaluate them, and return a
+        list of dicts, each with the following format:
 
         {
            "submetric": str,
            "value": float,
            "higher_is_better": bool,
-           "aggregation": (list -> float),
+           "aggregation": ([float] -> float),
        }
 
        * `submetric` should be the name of the metric
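
Note: a sketch of how a consumer of this format might fold per-document results into final metrics, assuming each task emits the list-of-dicts shape documented above. The helper name aggregate_results is hypothetical, not part of this commit:

    import collections

    def aggregate_results(per_doc_results):
        # per_doc_results: one list of result dicts per evaluated document
        values = collections.defaultdict(list)
        aggregations = {}
        for doc_results in per_doc_results:
            for r in doc_results:
                values[r["submetric"]].append(r["value"])
                aggregations[r["submetric"]] = r["aggregation"]
        # apply each submetric's aggregation fn ([float] -> float) to its values
        return {name: agg_fn(values[name]) for name, agg_fn in aggregations.items()}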
@@ -138,10 +143,12 @@ class Dataset(abc.ABC):
     def fewshot_context(self, doc, num_fewshot, provide_description):
         raw_description = self.fewshot_description()
         description = (raw_description + "\n===\n\n") if provide_description and raw_description else ""
 
         labeled_examples = "\n\n".join(
-            map(self.doc_to_text, self.fewshot_examples(k=num_fewshot))
+            [self.doc_to_text(doc) + self.doc_to_target(doc) for doc in self.fewshot_examples(k=num_fewshot)]
         ) + "\n\n"
 
-        example = self.doc_to_text(doc, include_target=False).strip()
+        example = self.doc_to_text(doc).strip()
 
         return description + labeled_examples + example
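
Note: a toy illustration of the prompt this now assembles, with doc_to_text supplying the unanswered stub and doc_to_target the answer for each few-shot example (all strings below are made up):

    description = "Answer yes or no.\n===\n\n"
    labeled_examples = (
        "passage one\nquestion: is it raining?\nanswer: no"
        + "\n\n"
        + "passage two\nquestion: is it sunny?\nanswer: yes"
    ) + "\n\n"
    example = "passage three\nquestion: is it windy?\nanswer:"
    prompt = description + labeled_examples + example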
@@ -153,12 +160,12 @@ def median(arr):
     return arr[len(arr) // 2]
 
-Request = collections.namedtuple('Request', ('type', 'args', 'kwargs'))
+Request = collections.namedtuple('Request', ('type', 'args'))
 
 class RequestFactory:
     def __getattr__(self, attr):
-        def fn(*args, **kwargs):
-            return Request(attr, args, kwargs)
+        def fn(*args):
+            return Request(attr, args)
         return fn
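
Note: how the factory behaves. Any attribute access on a RequestFactory instance (such as the module-level rf imported by the task code later in this diff) returns a builder that freezes the request type and its positional arguments into a Request tuple, deferring the actual LM call:

    rf = RequestFactory()
    req = rf.loglikelihood("The cat sat on the", " mat")
    assert req.type == "loglikelihood"
    assert req.args == ("The cat sat on the", " mat")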
...
@@ -3,6 +3,7 @@ import torch
 import torch.nn.functional as F
 
 from lm_eval.base import LM
 from lm_eval import utils
+from tqdm import tqdm
 
 class GPT2LM(LM):
@@ -20,7 +21,7 @@ class GPT2LM(LM):
     def loglikelihood(self, requests):
         res = []
         # TODO: vectorize properly
-        for context, continuation in requests:
+        for context, continuation in tqdm(requests):
             # when too long to fit in context, truncate from the left
             context_enc = self.tokenizer.encode(context)
             continuation_enc = self.tokenizer.encode(continuation)
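
Note: a minimal sketch of the left-truncation mentioned in the comment above, assuming a fixed context window of 1024 tokens (the helper is illustrative, not the code in this file). Dropping tokens from the left keeps the continuation intact so its log-likelihood can still be scored:

    def truncate_left(context_enc, continuation_enc, max_length=1024):
        # keep only the most recent max_length tokens of context + continuation
        full = context_enc + continuation_enc
        return full[-max_length:]

    tokens = truncate_left(list(range(2000)), [7, 8, 9])
    assert len(tokens) == 1024 and tokens[-3:] == [7, 8, 9]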
@@ -35,6 +36,6 @@ class GPT2LM(LM):
         return res
 
-    def gen_greedy(self, requests):
+    def greedy_until(self, requests):
         # TODO: implement
         pass
\ No newline at end of file
...
 import numpy as np
 from tqdm import auto as tqdm_lib
 
 from . common import HFTask, simple_accuracy_metric, yesno
+from lm_eval.base import rf, mean
 
 class BoolQ(HFTask):
     DATASET_PATH = "super_glue"
@@ -19,21 +19,33 @@ class BoolQ(HFTask):
     def fewshot_description(self):
         return "Read the following passages and answer each question with a yes or a no."
 
-    def doc_to_text(self, doc, include_target=True):
-        return f"{doc['passage']}\nquestion: {doc['question']}\nanswer: " \
-            + (yesno(doc['label']) if include_target else "")
+    def doc_to_text(self, doc):
+        return f"{doc['passage']}\nquestion: {doc['question']}\nanswer: "
 
-    def evaluate(self, docs, lm, provide_description, num_fewshot):
-        golds = [doc["label"] for doc in docs]
-        preds = []
-        for doc in docs:
-            ctx = self.fewshot_context(
-                doc=doc,
-                provide_description=provide_description,
-                num_fewshot=num_fewshot,
-            )
-            preds.append(lm.loglikelihood(ctx, ' yes') > lm.loglikelihood(ctx, ' no'))
-        return simple_accuracy_metric(preds=preds, golds=golds)
+    def doc_to_target(self, doc):
+        return yesno(doc['label'])
+
+    def construct_requests(self, ctx):
+        ll_yes = rf.loglikelihood(ctx, ' yes')
+        ll_no = rf.loglikelihood(ctx, ' no')
+
+        return ll_yes, ll_no
+
+    def process_results(self, doc, results):
+        ll_yes, ll_no = results
+        gold = doc["label"]
+
+        acc = 1. if (ll_yes > ll_no) == gold else 0.
+
+        return [
+            {
+                "submetric": "acc",
+                "value": acc,
+                "higher_is_better": True,
+                "aggregation": mean,
+            }
+        ]
 
 class CommitmentBank(HFTask):
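
Note: a hypothetical end-to-end flow for one BoolQ document under the new two-phase API. The log-likelihood numbers below are invented; a real run would load the dataset and execute the requests against an LM:

    task = BoolQ()
    doc = {"passage": "Water boils at 100 C at sea level.",
           "question": "does water boil at 100 C",
           "label": 1}
    ctx = task.fewshot_context(doc=doc, num_fewshot=0, provide_description=False)
    ll_yes, ll_no = task.construct_requests(ctx)  # Request objects, no LM call yet
    # ...the evaluator runs the requests through the LM, then hands back floats:
    [res] = task.process_results(doc, (-0.3, -2.1))  # pretend log-likelihoods
    assert res["submetric"] == "acc" and res["value"] == 1.0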
...
@@ -3,6 +3,7 @@ import json
 import numpy as np
 import random
 import itertools
+import collections
 
 from lm_eval import models, tasks
@@ -30,17 +31,43 @@ def main():
     else:
         task_names = args.tasks.split(",")
     task_dict = tasks.get_task_dict(task_names)
+    task_dict_items = list(task_dict.items())
 
     results = {}
-    for task_name, task in task_dict.items():
+
+    requests = collections.defaultdict(list)
+    requests_lengths = collections.defaultdict(list)
+    for task_name, task in task_dict_items:
+        # TODO: fall back to test docs
         if not task.has_validation_docs():
             continue
-        result = task.evaluate(
-            docs=itertools.isslice(task.validation_docs(), 0, args.limit),
-            lm=lm,
-            provide_description=args.provide_description,
-            num_fewshot=args.num_fewshot,
-        )
-        results[task_name] = result
+
+        for doc in itertools.islice(task.validation_docs(), 0, args.limit):
+            ctx = task.fewshot_context(
+                doc=doc,
+                provide_description=args.provide_description,
+                num_fewshot=args.num_fewshot,
+            )
+
+            reqs = task.construct_requests(ctx)
+            lengths = collections.defaultdict(int)
+            for req in reqs:
+                requests[req.type].append(req)
+                lengths[req.type] += 1
+
+            for type, ct in lengths.items():
+                requests_lengths[type].append(ct)
+
+    # TODO: finish implementation
+    for reqname, reqs in requests.items():
+        lm_res = getattr(lm, reqname)([req.args for req in reqs])
+
+        for task_name, task in task_dict_items:
+            if not task.has_validation_docs():
+                continue
+
     dumped = json.dumps(results, indent=2)
     print(dumped)
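
Note: the loop above only batches. It groups every Request by type so each LM method runs once over one flat list. Scattering lm_res back into per-document result dicts is the open TODO; one way it could be completed, using the per-document counts recorded in requests_lengths, is sketched below (an assumption, not this commit's code):

    def scatter(flat_results, per_doc_counts):
        # split a flat list of LM outputs back into one chunk per document
        it = iter(flat_results)
        return [[next(it) for _ in range(n)] for n in per_doc_counts]

    chunks = scatter([0.1, 0.2, 0.3, 0.4], [2, 2])
    assert chunks == [[0.1, 0.2], [0.3, 0.4]]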
...