Commit 487f7811 authored by haileyschoelkopf

prelim. multiple choice support

parent e7f49cca
......@@ -22,4 +22,7 @@ HIGHER_IS_BETTER_REGISTRY = {
"bleu": True,
"chrf": True,
"ter": False,
"acc": True,
"acc_norm": True,
}
\ No newline at end of file
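Below is a minimal, self-contained sketch (not part of this commit) of how a consumer might read the registry above when reporting results; the format_result helper is hypothetical.

HIGHER_IS_BETTER_REGISTRY = {
    "bleu": True,
    "chrf": True,
    "ter": False,
    "acc": True,
    "acc_norm": True,
}

def format_result(metric_name: str, value: float) -> str:
    # hypothetical helper: annotate a score with the registry's direction
    direction = "higher is better" if HIGHER_IS_BETTER_REGISTRY.get(metric_name, True) else "lower is better"
    return f"{metric_name}={value:.4f} ({direction})"

print(format_result("acc_norm", 0.3123))  # acc_norm=0.3123 (higher is better)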
......@@ -11,6 +11,7 @@ class Instance:
resps: list = field(default_factory=list)
filtered_resps: dict = field(default_factory=dict)
# initialized after init
task_name: str = None
doc_id: str = None
repeats: str = None
......
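For context, a rough reconstruction of the Instance container with the new task_name field; only the fields visible in the hunk are taken from the diff, the request_type/doc/arguments/id_ fields and their defaults are assumptions.

from dataclasses import dataclass, field

@dataclass
class Instance:
    # the first four fields and their defaults are assumed; the rest mirror the hunk
    request_type: str = None
    doc: dict = None
    arguments: tuple = None
    id_: int = None
    resps: list = field(default_factory=list)
    filtered_resps: dict = field(default_factory=dict)
    # initialized after init
    task_name: str = None
    doc_id: str = None
    repeats: str = None

inst = Instance(request_type="loglikelihood", arguments=("Q: ...", " A"))
inst.task_name = "arc_challenge"  # the new field lets responses be grouped per task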
......@@ -10,7 +10,10 @@ import evaluate
AGGREGATION_REGISTRY = {}
METRIC_REGISTRY = {}
METRIC_REGISTRY = {
"acc": None,
"acc_norm": None,
}
def register_metric(name):
......@@ -45,6 +48,7 @@ searching in HF Evaluate library...")
def register_aggregation(name):
# TODO: should we enforce a specific interface to aggregation metrics?
def decorate(fn):
assert (
name not in AGGREGATION_REGISTRY
......@@ -155,6 +159,7 @@ def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
@register_metric("perplexity")
@register_aggregation("perplexity")
def perplexity(items):
return math.exp(-mean(items))
......
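A standalone sketch of the registration pattern in this file; the perplexity aggregation is copied from the hunk, while the mean helper from lm_eval.api.metrics is assumed to behave as shown.

import math

AGGREGATION_REGISTRY = {}

def register_aggregation(name):
    def decorate(fn):
        # refuse to silently overwrite an existing aggregation
        assert name not in AGGREGATION_REGISTRY, f"aggregation '{name}' already registered"
        AGGREGATION_REGISTRY[name] = fn
        return fn
    return decorate

def mean(items):
    return sum(items) / len(items)

@register_aggregation("perplexity")
def perplexity(items):
    # items are per-token loglikelihoods; exp of the negative mean gives perplexity
    return math.exp(-mean(items))

print(AGGREGATION_REGISTRY["perplexity"]([-2.0, -1.0, -3.0]))  # ~7.389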
import abc
from typing import Union
from lm_eval import utils
MODEL_REGISTRY = {}
def register_model(name):
# TODO: should fairseq/elk be cited for this design pattern?
def register_model(*names):
# either pass a list or a single alias.
# function receives them as a tuple of strings
def decorate(cls):
assert (
issubclass(cls, LM)
), f"Model '{name}' ({cls.__name__}) must extend LM class"
for name in names:
assert (
issubclass(cls, LM)
), f"Model '{name}' ({cls.__name__}) must extend LM class"
assert (
name not in MODEL_REGISTRY
), f"Model named '{name}' conflicts with existing model!"
assert (
name not in MODEL_REGISTRY
), f"Model named '{name}' conflicts with existing model! Please register with a non-conflicting alias instead."
MODEL_REGISTRY[name] = cls
MODEL_REGISTRY[name] = cls
return cls
return decorate
......
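Self-contained sketch of the multi-alias registration introduced above; DummyLM is hypothetical, and the decorator body mirrors the new loop over names.

MODEL_REGISTRY = {}

class LM:
    pass

def register_model(*names):
    # one class may now be registered under several aliases
    def decorate(cls):
        for name in names:
            assert issubclass(cls, LM), f"Model '{name}' ({cls.__name__}) must extend LM class"
            assert name not in MODEL_REGISTRY, f"Model named '{name}' conflicts with existing model!"
            MODEL_REGISTRY[name] = cls
        return cls
    return decorate

@register_model("hf-causal", "gpt2")
class DummyLM(LM):
    pass

assert MODEL_REGISTRY["gpt2"] is MODEL_REGISTRY["hf-causal"] is DummyLM

Both aliases resolve to the same class object, which is how the @register_model("hf-causal", "gpt2") change further down keeps the legacy gpt2 model name working.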
......@@ -5,13 +5,15 @@ import re
import evaluate
import random
import itertools
import functools
import datasets
import numpy as np
from typing import List, Union
from lm_eval.api import METRIC_REGISTRY, AGGREGATION_REGISTRY, HIGHER_IS_BETTER_REGISTRY
from lm_eval.api.metrics import METRIC_REGISTRY, AGGREGATION_REGISTRY
from lm_eval.api import HIGHER_IS_BETTER_REGISTRY
from lm_eval.api.instance import Instance
from lm_eval.api.metrics import get_metric, get_aggregation, mean, weighted_perplexity, bits_per_byte
from lm_eval import utils
......@@ -36,10 +38,11 @@ class TaskConfig(dict):
doc_to_text: str = ""
doc_to_target: str = ""
# aggregation: dict = None # TODO: remove, I think these 2 are obsolete w/ current metric_list impl.
# higher_is_better: dict = None
num_fewshot: int = 0
batch_size: int = 1
repeats: int = 1
metric_list: str = None
gold_alias: str = None
output_type: str = "greedy_until"
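Illustrative (not taken from the diff) values for the new TaskConfig fields, shaped like the per-metric dicts that the YAML task configs further down provide.

example_config = dict(
    output_type="multiple_choice",
    num_fewshot=0,
    batch_size=1,
    repeats=1,
    gold_alias=None,
    metric_list=[
        {"metric": "acc", "aggregation": "mean", "higher_is_better": True},
        {"metric": "acc_norm", "aggregation": "mean", "higher_is_better": True},
    ],
)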
......@@ -122,7 +125,8 @@ class Task(abc.ABC):
filter_pipeline = build_filter_ensemble(name, components)
self._filters.append(filter_pipeline)
self.sampler = samplers.Sampler(self.training_docs(), self, rnd=random.Random()) # TODO: pass the correct docs in here
self.sampler = samplers.Sampler(self.fewshot_docs(), self, rnd=random.Random()) # TODO: pass the correct docs in here
def download(self, data_dir=None, cache_dir=None, download_mode=None):
"""Downloads and returns the task dataset.
......@@ -193,6 +197,19 @@ class Task(abc.ABC):
"""
return []
def fewshot_docs(self):
"""
:return: Iterable[obj]
An iterable of objects that doc_to_text can handle
"""
if self.has_training_docs():
return self.training_docs()
elif self.has_validation_docs():
return self.validation_docs()
else:
# TODO: should we allow this case to occur? / should raise a warning here
return self.test_docs()
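A tiny check of the fallback order encoded above, using a stand-in class rather than the real Task: few-shot examples come from the training split when present, otherwise validation, otherwise test.

class ToyTask:
    def has_training_docs(self): return False
    def has_validation_docs(self): return True
    def training_docs(self): return []
    def validation_docs(self): return ["val doc 1", "val doc 2"]
    def test_docs(self): return ["test doc"]
    def fewshot_docs(self):
        if self.has_training_docs():
            return self.training_docs()
        elif self.has_validation_docs():
            return self.validation_docs()
        return self.test_docs()

print(ToyTask().fewshot_docs())  # ['val doc 1', 'val doc 2']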
def _process_doc(self, doc):
"""
Override this to process (detokenize, strip, replace, etc.) individual
......@@ -336,33 +353,33 @@ class Task(abc.ABC):
labeled_examples = ""
else:
# labeled_examples = self.sampler.get_context(doc, self._config.num_fewshot)
labeled_examples = self.sampler.get_context(doc, self._config.num_fewshot)
# for sets with no training docs, draw from other set *but ensure no overlap with current doc*
if self.has_training_docs():
fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
else:
if self._fewshot_docs is None:
self._fewshot_docs = list(
self.validation_docs()
if self.has_validation_docs()
else self.test_docs()
)
fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)
# get rid of the doc that's the one we're evaluating, if it's in the fewshot
fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]
labeled_examples = (
"\n\n".join(
[
self.doc_to_text(doc) + self.doc_to_target(doc)
for doc in fewshotex
]
)
+ "\n\n"
)
# if self.has_training_docs():
# fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
# else:
# if self._fewshot_docs is None:
# self._fewshot_docs = list(
# self.validation_docs()
# if self.has_validation_docs()
# else self.test_docs()
# )
# fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)
# # get rid of the doc that's the one we're evaluating, if it's in the fewshot
# fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]
# labeled_examples = (
# "\n\n".join(
# [
# self.doc_to_text(doc) + self.doc_to_target(doc)
# for doc in fewshotex
# ]
# )
# + "\n\n"
# )
example = self.doc_to_text(doc)
return labeled_examples + example
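The real sampler lives in lm_eval.api.samplers and its interface is only inferred here; this sketch mirrors the commented-out logic above to show roughly what get_context is expected to return (the QA stand-in task is hypothetical).

import random

class ToySampler:
    def __init__(self, docs, task, rnd=None):
        self.docs, self.task, self.rnd = list(docs), task, rnd or random.Random()

    def get_context(self, doc, num_fewshot):
        # over-sample by one, drop the doc under evaluation if drawn, then truncate
        pool = self.rnd.sample(self.docs, num_fewshot + 1)
        pool = [d for d in pool if d != doc][:num_fewshot]
        return "\n\n".join(
            self.task.doc_to_text(d) + self.task.doc_to_target(d) for d in pool
        ) + "\n\n"

class QA:
    def doc_to_text(self, d): return f"Q: {d['q']}\nA:"
    def doc_to_target(self, d): return f" {d['a']}"

docs = [{"q": "2+2?", "a": "4"}, {"q": "Capital of France?", "a": "Paris"}, {"q": "Color of the sky?", "a": "blue"}]
print(ToySampler(docs, QA(), rnd=random.Random(0)).get_context(docs[0], num_fewshot=1))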
......@@ -376,7 +393,7 @@ class Task(abc.ABC):
class ConfigurableTask(Task):
VERSION = "2.0"
OUTPUT_TYPE = "greedy_until"
OUTPUT_TYPE = None
def __init__(
self, data_dir=None, cache_dir=None, download_mode=None, config: dict = None
......@@ -432,6 +449,8 @@ class ConfigurableTask(Task):
for name, components in self._config.get("filters", [["none", ["take_first"]]]):
filter_pipeline = build_filter_ensemble(name, components)
self._filters.append(filter_pipeline)
self.sampler = samplers.Sampler(list(self.fewshot_docs()), self, rnd=random.Random()) # TODO: pass the correct docs in here
def has_training_docs(self):
if self._config.training_split is not None:
......@@ -463,6 +482,13 @@ class ConfigurableTask(Task):
if self._config.test_split is not None:
return self.dataset[self._config.test_split]
def fewshot_docs(self):
if self._config.fewshot_split:
return self.dataset[self._config.fewshot_split]
else:
# TODO: warn user if fewshot split isn't explicitly set
return super().fewshot_docs()
def should_decontaminate(self):
return self._config.should_decontaminate
......@@ -497,6 +523,19 @@ class ConfigurableTask(Task):
arguments=(ctx, self.doc_to_target(doc))
elif self.OUTPUT_TYPE == "loglikelihood_rolling":
arguments=(self.doc_to_target(doc),)
elif self.OUTPUT_TYPE == "multiple_choice":
import ast
return [
Instance(
request_type="loglikelihood",
doc=doc,
arguments=(ctx, " {}".format(choice)),
id_=i,
**kwargs,
)
for i, choice in enumerate(ast.literal_eval(utils.apply_template(self._config.template_aliases + "{{answer_choices}}", doc)))
# we pass the user-defined answer_choices var (in aliases) and echo the result. TODO: any cleaner way to do this?
]
elif self.OUTPUT_TYPE == "greedy_until":
arguments=(ctx, self._config.delimiter)
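Rough illustration of how the multiple_choice branch above recovers the per-doc answer choices; utils.apply_template is approximated with plain Jinja2, and the document fields follow the ai2_arc schema used later in this commit. The template_aliases string defines answer_choices, "{{answer_choices}}" echoes it back as a Python-literal string, and ast.literal_eval turns that back into a list with one loglikelihood request per choice.

import ast
from jinja2 import Environment

doc = {"question": "Which is a gas at room temperature?",
       "choices": {"text": ["iron", "oxygen", "wood"], "label": ["A", "B", "C"]},
       "answerKey": "B"}

template_aliases = "{% set answer_choices = choices['text'] %}"
rendered = Environment().from_string(template_aliases + "{{answer_choices}}").render(**doc)
choices = ast.literal_eval(rendered)  # "['iron', 'oxygen', 'wood']" -> list
print([" {}".format(c) for c in choices])  # one loglikelihood continuation per choice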
......@@ -504,6 +543,7 @@ class ConfigurableTask(Task):
request_type=self.OUTPUT_TYPE,
doc=doc,
arguments=arguments,
id_=0,
**kwargs
)
......@@ -516,6 +556,22 @@ class ConfigurableTask(Task):
result_dict = {"perplexity": ll, "accuracy": int(is_greedy)}
elif self.OUTPUT_TYPE == "loglikelihood_rolling":
pass
elif self.OUTPUT_TYPE == "multiple_choice":
lls = [res[0] for res in results] # only retain loglikelihoods, discard is_greedy
gold = int(self.doc_to_target(doc)) # TODO: if `gold` here is an integer/ds label obj, map it to answer_choice
# TODO: remove dependence on "gold" and "choices" columns
acc = 1.0 if np.argmax(lls) == gold else 0.0
completion_len = np.array([float(len(i)) for i in doc["choices"]])
acc_norm = 1.0 if np.argmax(lls / completion_len) == gold else 0.0
# TODO: set which normalization metrics should be reported, and calculate them
# TODO: add mutual info.
result_dict = {
"acc": acc,
"acc_norm": acc_norm,
}
elif self.OUTPUT_TYPE == "greedy_until":
if self._config.gold_alias is not None:
......@@ -531,6 +587,10 @@ class ConfigurableTask(Task):
)
result_dict[key] = _dict[key]
else:
raise ValueError(f"Passed invalid output_type '{self.OUTPUT_TYPE}' ! Please use one of ",
"'loglikelihood', 'loglikelihood_rolling', 'greedy_until'"
)
return result_dict
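Worked example (made-up numbers) of the acc / acc_norm arithmetic in the multiple_choice branch above.

import numpy as np

lls = [-12.3, -8.1, -9.4]                       # summed loglikelihood of each choice
choices = ["iron", "oxygen", "wood"]
gold = 1                                        # index of the correct choice
acc = 1.0 if np.argmax(lls) == gold else 0.0    # raw argmax over loglikelihoods
completion_len = np.array([float(len(c)) for c in choices])
# length normalization: avoids penalizing longer completions for accumulating more loglikelihood mass
acc_norm = 1.0 if np.argmax(lls / completion_len) == gold else 0.0
print(acc, acc_norm)  # 1.0 1.0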
......@@ -558,11 +618,6 @@ class MultipleChoiceTask(Task):
**kwargs,
)
for i, choice in enumerate(doc["choices"])]
#lls = [
# rf.loglikelihood(ctx, " {}".format(choice))[0] for choice in doc["choices"]
# ]
# return lls
def process_results(self, doc, results):
results = [res[0] for res in results] # only retain loglikelihoods, discard is_greedy TODO: do we need is_greedy anywhere?
......@@ -668,19 +723,22 @@ class PerplexityTask(Task, abc.ABC):
TASK_REGISTRY = {}
ALL_TASKS = []
def register_task(name):
def register_task(*names):
# either pass a list or a single alias.
# function receives them as a tuple of strings
def decorate(cls):
assert (
issubclass(cls, Task)
), f"Task '{name}' ({cls.__name__}) must extend Task class"
for name in names:
assert (
issubclass(cls, Task)
), f"Task '{name}' ({cls.__name__}) must extend Task class"
assert (
name not in TASK_REGISTRY
), f"Task named '{name}' conflicts with existing task!"
assert (
name not in TASK_REGISTRY
), f"Task named '{name}' conflicts with existing task! Please register with a non-conflicting alias instead."
TASK_REGISTRY[name] = cls
ALL_TASKS = sorted(list(TASK_REGISTRY)) # TODO: this doesn't seem to import right.
TASK_REGISTRY[name] = cls
ALL_TASKS = sorted(list(TASK_REGISTRY)) # TODO: this doesn't seem to import right.
return cls
return decorate
......
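One plausible reading of the "doesn't seem to import right" TODO above: the assignment inside decorate rebinds ALL_TASKS as a function-local name at decoration time, so the module-level list stays empty and modules that did `from lm_eval.tasks import ALL_TASKS` never see the registered tasks. A hedged fix is to mutate the existing list object in place instead of rebinding it.

TASK_REGISTRY = {}
ALL_TASKS = []

def register_task(*names):
    def decorate(cls):
        for name in names:
            TASK_REGISTRY[name] = cls
            # update the existing list object so earlier from-imports see the change
            ALL_TASKS.clear()
            ALL_TASKS.extend(sorted(TASK_REGISTRY))
        return cls
    return decorate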
......@@ -145,7 +145,8 @@ def evaluate(
# for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):
task.build_all_requests(limit=limit)
# aggregate Instances by LM method requested to get output.
requests[task.OUTPUT_TYPE].extend(task.instances)
reqtype = "loglikelihood" if task.OUTPUT_TYPE == "multiple_choice" else task.OUTPUT_TYPE #TODO: this is hacky, fix in task.py
requests[reqtype].extend(task.instances)
### Run LM on inputs, get all outputs ###
# execute each type of request
......
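Toy illustration of the routing tweak above: multiple_choice tasks emit plain loglikelihood Instances, so their requests are pooled with other loglikelihood work before being sent to the model. Task here is a stand-in namedtuple, not the harness class.

from collections import defaultdict, namedtuple

Task = namedtuple("Task", ["OUTPUT_TYPE", "instances"])
tasks = [
    Task("multiple_choice", ["arc req 0", "arc req 1"]),
    Task("loglikelihood", ["lambada req 0"]),
]

requests = defaultdict(list)
for task in tasks:
    # multiple_choice decomposes into per-choice loglikelihood calls
    reqtype = "loglikelihood" if task.OUTPUT_TYPE == "multiple_choice" else task.OUTPUT_TYPE
    requests[reqtype].extend(task.instances)

print(dict(requests))  # {'loglikelihood': ['arc req 0', 'arc req 1', 'lambada req 0']}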
......@@ -9,7 +9,7 @@ from lm_eval import utils
from lm_eval.api.model import LM, register_model
@register_model("hf-causal")
@register_model("hf-causal", "gpt2")
class HFLM(LM):
def __init__(
self,
......
dataset_path: ai2_arc
dataset_name: ARC-Challenge
output_type: multiple_choice
training_split: train
validation_split: validation
test_split: test
doc_to_text: "Q: {{question}}\nA:"
doc_to_target: "{% set answer_choices = doc['choices']['text'] %}{{answer_choices[int(doc['answerKey']) - 1]}}"
template_aliases: "{% set answer_choices = choices['text'] %}{% set gold = choices.label.index(answerKey) %}" # set the list of possible answer choices, and set what this doc's gold answer is (set what ds column used, and what)
doc_to_text: "Question: {{question}}\nAnswer:"
doc_to_target: "{{gold}}"
metric_list:
- metric: exact_match
- metric: acc
aggregation: mean
higher_is_better: true
- metric: acc_norm
aggregation: mean
higher_is_better: true
\ No newline at end of file
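Hedged illustration of how the template_aliases line above could resolve for one ARC-Challenge-style record; plain Jinja2 stands in for the harness's template handling, and the field names follow the ai2_arc schema.

from jinja2 import Environment

doc = {"question": "Which is a renewable resource?",
       "choices": {"text": ["coal", "wind", "oil", "gas"], "label": ["A", "B", "C", "D"]},
       "answerKey": "B"}

aliases = "{% set answer_choices = choices['text'] %}{% set gold = choices.label.index(answerKey) %}"
env = Environment()
print(env.from_string(aliases + "{{answer_choices}}").render(**doc))  # ['coal', 'wind', 'oil', 'gas']
print(env.from_string(aliases + "{{gold}}").render(**doc))            # 1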
dataset_path: EleutherAI/lambada_openai
dataset_name: default
output_type: "loglikelihood"
output_type: loglikelihood
test_split: test
template_aliases: "{% set hypo = hypothesis %}"
template_aliases: ""
doc_to_text: "{{text.split(' ')[:-1]|join(' ')}}"
doc_to_target: "{{' '+text.split(' ')[-1]}}"
should_decontaminate: true
......@@ -12,5 +12,5 @@ metric_list:
aggregation: perplexity
higher_is_better: true
- metric: accuracy
aggregation: perplexity
aggregation: mean
higher_is_better: true
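Quick check of the doc_to_text / doc_to_target templates above on an illustrative record (plain Jinja2, not the harness's own rendering): the prompt drops the final word and the target is that word with a leading space.

from jinja2 import Environment

doc = {"text": "She opened the door and saw her best friend"}
env = Environment()
doc_to_text = "{{text.split(' ')[:-1]|join(' ')}}"
doc_to_target = "{{' '+text.split(' ')[-1]}}"
print(repr(env.from_string(doc_to_text).render(**doc)))    # 'She opened the door and saw her best'
print(repr(env.from_string(doc_to_target).render(**doc)))  # ' friend'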