Commit c4c20ff5 authored by lintangsutawika

pre-commit stuff

parent e56b950a
@@ -45,7 +45,7 @@ python main.py \
    --device cuda:0
```
Additional arguments can be provided to the model constructor using the `--model_args` flag. Most notably, this supports the common practice of using the `revisions` feature on the Hub to store partially trained checkpoints:
```bash
python main.py \
@@ -64,8 +64,8 @@ To use with [PEFT](https://github.com/huggingface/peft), take the call you would
python main.py \
    --model hf-causal \
    --model_args pretrained=EleutherAI/gpt-j-6b,peft=nomic-ai/gpt4all-j-lora \
    --tasks openbookqa,arc_easy,winogrande,hellaswag,arc_challenge,piqa,boolq \
    --device cuda:0
```
Our library also supports the OpenAI API:
@@ -78,7 +78,7 @@ python main.py \
    --tasks lambada_openai,hellaswag
```
While this functionality is only officially maintained for the official OpenAI API, it tends to also work for other hosting services that use the same API, such as [goose.ai](goose.ai), with minor modifications. We also have an implementation for the [TextSynth](https://textsynth.com/index.html) API, using `--model textsynth`.
To verify the data integrity of the tasks you're performing in addition to running the tasks themselves, you can use the `--check_integrity` flag:
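For example (an illustrative invocation pieced together from the commands above; the model and task choices are arbitrary):
```bash
python main.py \
    --model hf-causal \
    --model_args pretrained=EleutherAI/gpt-j-6b \
    --tasks hellaswag \
    --device cuda:0 \
    --check_integrity
```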
@@ -116,7 +116,7 @@ When reporting eval harness results, please also report the version of each task
## Test Set Decontamination
To address concerns about train / test contamination, we provide utilities for comparing results on a benchmark using only the data points not found in the model's training set. Unfortunately, outside of models trained on the Pile and C4, it is very rare that people who train models disclose the contents of the training data. However, this utility can be useful to evaluate models you have trained on private data, provided you are willing to pre-compute the necessary indices. We provide computed indices for 13-gram exact match deduplication against the Pile, and plan to add additional precomputed dataset indices in the future (including C4 and min-hash LSH deduplication).
For details on text decontamination, see the [decontamination guide](./docs/decontamination.md).
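A sketch of how this might look on the command line; the `--decontamination_ngrams_path` flag name is inferred from the `decontamination_ngrams_path` argument of `evaluate()` further down in this diff, and the path is a placeholder:
```bash
python main.py \
    --model hf-causal \
    --model_args pretrained=EleutherAI/gpt-j-6b \
    --tasks hellaswag \
    --decontamination_ngrams_path /path/to/pile_13gram_indices \
    --device cuda:0
```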
...
dataset_path: super_glue
dataset_name: cb
training_split: train
validation_split: validation
template_aliases: "{% set hypo = hypothesis %}"
doc_to_text: "Suppose {{premise}} Can we infer that \"{{hypo}}\"? Yes, no, or maybe?"
doc_to_target: "{% set answer_choices = ['Yes', 'No', 'Maybe'] %}{{answer_choices[label]}}"
metric_list:
  - metric: exact_match
    aggregation: mean
    higher_is_better: true
    ignore_case: true
    ignore_punctuation: true
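For orientation, a rough sketch of how a YAML config like this one becomes a runnable task, mirroring the `TaskConfig`/`ConfigurableTask` machinery and the registration loop that appear later in this diff (the file name and import paths are assumptions):
```python
# Hypothetical sketch: turn the YAML above into a ConfigurableTask subclass.
from lm_eval import utils
from lm_eval.api.task import ConfigurableTask, TaskConfig  # assumed module path

config = utils.load_yaml_config("super_glue_cb.yaml")  # placeholder file name
CBTask = type(
    "cbConfigurableTask",
    (ConfigurableTask,),
    {"CONFIG": TaskConfig(**config)},
)
task = CBTask()
```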
from . import metrics
@@ -3,9 +3,10 @@ from typing import List
from lm_eval.api.instance import Instance


class Filter:
    """
    Filter classes operate on a per-task level.
    They take all model outputs (`instance.resps` for all `task.instances`)
    across all instances of a task, and perform operations.
    In a single run, one can configure any number of separate filters or lists of filters.
@@ -25,30 +26,33 @@ class Filter:
        [<filtered resps for instance 0>, <filtered resps for instance 1>]
        """
        return resps
@dataclass
class FilterEnsemble:
    """
    FilterEnsemble creates a pipeline applying multiple filters.
    Its intended usage is to stack multiple post-processing steps in order.
    `task.apply_filters` should use a list of FilterEnsemble classes that it stores, to apply each
    pipeline separately.
    """

    name: str
    filters: List[Filter]

    def apply(self, instances: List[Instance]):
        resps = [
            inst.resps for inst in instances
        ]  # operate just on the model responses

        for f in self.filters:
            # apply filters in sequence
            out = f.apply(resps)
            resps = (
                out  # TODO: handle the case where a filter returns multiple "buckets"
            )

        # add the end results after filtering to filtered_requests of their respective source instances.
        # has key `self.name`: each FilterEnsemble applied in a given run should use a different name.
        for inst, resp in zip(instances, resps):
            inst.filtered_resps[self.name] = resp
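A minimal usage sketch (not part of this diff), assuming the `RegexFilter` and `TakeFirstFilter` classes from the filter modules further below: an ensemble is built from concrete filters and, when applied, stores its output on each instance under its own name.
```python
# Hypothetical sketch: chain a regex-extraction step and a take-first step,
# then apply the pipeline to a task's instances.
ensemble = FilterEnsemble(
    name="get-answer",
    filters=[RegexFilter(regex_pattern=r"#### (\-?[0-9\.\,]+)"), TakeFirstFilter()],
)
ensemble.apply(task.instances)
# each instance now carries inst.filtered_resps["get-answer"]
```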
from dataclasses import dataclass, field
from typing import Literal, Tuple


@dataclass
class Instance:
    request_type: str = Literal[
        "loglikelihood", "loglikelihood_rolling", "greedy_until"
    ]
    doc: dict = None
    arguments: tuple = None
    idx: int = None
    metadata: tuple = Tuple[str, int, int]  # TODO: better typehints here
    resps: list = field(default_factory=list)
    filtered_resps: dict = field(default_factory=dict)
@@ -19,10 +22,12 @@ class Instance:
    def __post_init__(self):
        # unpack metadata field
        self.task_name, self.doc_id, self.repeats = self.metadata

    @property
    def args(self):
        """
        Returns (string,) where `string` is the string to calculate loglikelihood over
        """
        return (
            self.arguments if isinstance(self.arguments, tuple) else (self.arguments,)
        )
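For orientation, a minimal sketch (illustrative values, not from this diff) of how a single request instance is built and consumed:
```python
# Hypothetical sketch: a loglikelihood request over a (context, continuation) pair.
# metadata is the (task_name, doc_id, repeats) triple unpacked in __post_init__.
inst = Instance(
    request_type="loglikelihood",
    doc={"question": "What is 2 + 2?", "answer": "4"},
    arguments=("Q: What is 2 + 2?\nA:", " 4"),
    idx=0,
    metadata=("toy_task", 0, 1),
)
assert inst.args == ("Q: What is 2 + 2?\nA:", " 4")
```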
@@ -26,7 +26,6 @@ HIGHER_IS_BETTER_REGISTRY = {
    "bleu": True,
    "chrf": True,
    "ter": False,
    "acc": True,
    "acc_norm": True,
    "acc_mutual_info": True,
@@ -35,6 +34,7 @@ HIGHER_IS_BETTER_REGISTRY = {
    "bits_per_byte": False,
}


def register_metric(name):
    # TODO: do we want to enforce a certain interface to registered metrics?
    def decorate(fn):
@@ -44,7 +44,7 @@ def register_metric(name):
        METRIC_REGISTRY[name] = fn
        return fn

    return decorate

@@ -54,12 +54,14 @@ def get_metric(name):
        return METRIC_REGISTRY[name]
    except KeyError:
        # TODO: change this print to logging?
        print(
            f"Could not find registered metric '{name}' in lm-eval, \
searching in HF Evaluate library..."
        )
        try:
            metric_object = evaluate.load(name)
            return metric_object.compute
        except Exception:
            raise Warning(
                "{} not found in the evaluate library!".format(name),
                "Please check https://huggingface.co/evaluate-metric",
@@ -75,7 +77,7 @@ def register_aggregation(name):
        AGGREGATION_REGISTRY[name] = fn
        return fn

    return decorate
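As a rough usage sketch (the call signatures are assumptions, since the TODO above notes that no metric interface is enforced yet; only the decorator mechanics come from this file):
```python
# Hypothetical sketch: register a custom metric and a custom aggregation.
@register_metric("length_ratio")
def length_ratio(prediction, reference):
    # assumed per-example signature
    return len(prediction) / max(len(reference), 1)


@register_aggregation("median")
def median(scores):
    scores = sorted(scores)
    return scores[len(scores) // 2]
```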
...
@@ -6,14 +6,15 @@ from lm_eval import utils
MODEL_REGISTRY = {}


def register_model(*names):
    # either pass a list or a single alias.
    # function receives them as a tuple of strings
    def decorate(cls):
        for name in names:
            assert issubclass(
                cls, LM
            ), f"Model '{name}' ({cls.__name__}) must extend LM class"

            assert (
@@ -22,7 +23,7 @@ def register_model(*names):
            MODEL_REGISTRY[name] = cls
        return cls

    return decorate
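A short sketch of registering a new backend (the class body and method names are assumptions; only the decorator usage and the requirement to subclass `LM` come from the code above):
```python
# Hypothetical sketch: register a custom model backend under two aliases.
@register_model("my-backend", "my-backend-alias")
class MyBackendLM(LM):
    def loglikelihood(self, requests):
        ...

    def loglikelihood_rolling(self, requests):
        ...

    def greedy_until(self, requests):
        ...
```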
...
@@ -5,6 +5,7 @@ group_registry = {}
task2func_index = {}
func2task_index = {}


def register_task(name):
    def wrapper(func):
@@ -15,16 +16,16 @@ def register_task(name):
    return wrapper


def register_group(name):
    def wrapper(func):
        func_name = func2task_index[func.__name__]
        if name in group_registry:
            group_registry[name].append(func_name)
        else:
            group_registry[name] = [func_name]
        return func

    return wrapper
class Sampler:
    def __init__(self, docs, task, fewshot_indices=None, rnd=None):
        self.rnd = rnd
@@ -12,15 +9,18 @@ class Sampler:
        self.delimiter = self.config.delimiter

        self.docs = docs  # HF dataset split, provided by task._fewshot_docs()
        if fewshot_indices:  # subset few-shot docs from
            self.docs = self.docs.select(fewshot_indices)

    def get_context(self, doc, num_fewshot):
        # draw an extra fewshot sample if using same split as evaluating on
        n_samples = (
            num_fewshot + 1
            if self.config.fewshot_split == self.config.test_split
            else num_fewshot
        )

        # draw `n_samples` docs from fewshot_docs
        fewshotex = self.sample(n_samples)
@@ -28,16 +28,16 @@ class Sampler:
        # get rid of the doc that's the one we're evaluating, if it's in the fewshot
        # TODO: should we just stop people from using fewshot from same split as evaluating?
        selected_docs = [x for x in fewshotex if x != doc][:num_fewshot]

        labeled_examples = (
            self.delimiter.join(
                [
                    self.task.doc_to_text(doc) + self.task.doc_to_target(doc)
                    for doc in selected_docs
                ]
            )
            + self.delimiter
        )

        # only returns the fewshot context! Does not append the document, do this outside the object
        return labeled_examples
@@ -51,25 +51,22 @@ class Sampler:


class BalancedSampler(Sampler):
    def sample(self, n):
        """
        TODO: this should return approximately class-balanced samples from our fewshot examples.
        TODO: what order should they be in? maybe random?
        """

        pass


class ManualSampler(Sampler):
    def sample(self, n):
        """ """
        pass


# TODO: how should we do design here? might be better to have a single sampler and pass more kwargs at init.
# Depends what's easier for new user to add own functionality on top of

# types of sampler:
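Illustrative usage (a sketch under assumptions, not part of this diff): a sampler is built from a task's few-shot docs, and its `get_context` output is prepended to the evaluated document's own prompt, as the comment in `get_context` notes.
```python
# Hypothetical sketch: build a 2-shot context for one document.
# `task` is assumed to be a constructed ConfigurableTask (see task.py below).
import random

sampler = Sampler(list(task.fewshot_docs()), task, rnd=random.Random(1234))
fewshot_ctx = sampler.get_context(doc, num_fewshot=2)
full_prompt = fewshot_ctx + task.doc_to_text(doc)
```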
...
@@ -19,9 +19,15 @@ from lm_eval import utils
from lm_eval.api import samplers
from lm_eval.api.instance import Instance
from lm_eval.api.metrics import (
    METRIC_REGISTRY,
    AGGREGATION_REGISTRY,
    HIGHER_IS_BETTER_REGISTRY,
    get_metric,
    get_aggregation,
    mean,
    weighted_perplexity,
    bits_per_byte,
)
from lm_eval.logger import eval_logger
from lm_eval.prompts import get_prompt
@@ -35,15 +41,17 @@ class TaskConfig(dict):
    group: str = None
    names: str = None
    reference: str = None
    task_name: str = (
        None  # TODO: deprecate this, it'll be set in __post_init__ to be names[0]
    )
    base_task: str = None
    dataset_path: str = None
    dataset_name: str = None
    training_split: str = None
    validation_split: str = None
    test_split: str = None
    fewshot_split: str = None  # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaling (?)
    template_aliases: str = None
    doc_to_text: Union[Callable, str] = None
    doc_to_target: Union[Callable, str] = None
@@ -57,18 +65,20 @@ class TaskConfig(dict):
    output_type: str = "greedy_until"
    delimiter: str = "\n\n"
    filter_list: Union[str, list] = None
    normalization: str = (
        None  # TODO: add length-normalization of various types, mutual info
    )
    should_decontaminate: bool = False
    doc_to_decontamination_query: str = None
    use_prompt: str = None
    metadata: str = None  # by default, not used in the code. allows for users to pass arbitrary info to tasks

    def __post_init__(self):
        # allow user-specified aliases so that users can
        # force prompt-compatibility for some prompt regardless of
        # field names in prompt
        if self.template_aliases is not None:
            if type(self.doc_to_text) == str:
                self.doc_to_text = self.template_aliases + self.doc_to_text
@@ -103,6 +113,7 @@ class Task(abc.ABC):
    DATASET_NAME: str = None
    OUTPUT_TYPE: str = None

    def __init__(
        self,
        data_dir=None,
@@ -141,12 +152,15 @@ class Task(abc.ABC):
        if not hasattr(self, "_filters"):
            self._filters = []
            for name, components in self._config.get(
                "filters", [["none", ["take_first"]]]
            ):
                filter_pipeline = build_filter_ensemble(name, components)
                self._filters.append(filter_pipeline)

        self.sampler = samplers.Sampler(
            list(self.fewshot_docs()), self, rnd=random.Random()
        )  # TODO: pass the correct docs in here
    def download(self, data_dir=None, cache_dir=None, download_mode=None):
        """Downloads and returns the task dataset.
@@ -230,7 +244,7 @@ class Task(abc.ABC):
            eval_logger.warning(
                "has_training_docs and has_validation_docs are False"
                "using test_docs but this is not recommended."
            )
            return self.test_docs()

    def _process_doc(self, doc):
@@ -283,19 +297,24 @@ class Task(abc.ABC):
        ), f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!"

        instances = []
        for doc_id, doc in enumerate(
            itertools.islice(docs, 0, limit) if limit else docs
        ):
            # sample fewshot context
            fewshot_ctx = self.fewshot_context(
                doc, self._config.num_fewshot, rnd=random.Random()
            )

            # TODO: hardcoded for now: # of runs on each input to be 2. # TODO: we should override this if doing greedy gen so users don't waste time+compute
            inst = self.construct_requests(
                doc=doc,
                ctx=fewshot_ctx,
                metadata=(self._config["task_name"], doc_id, self._config.repeats),
            )

            if not isinstance(inst, list):
                inst = [inst]
            instances.extend(inst)

        self._instances = instances
        assert len(self._instances) != 0, "task.build_requests() did not find any docs!"
@@ -316,7 +335,7 @@ class Task(abc.ABC):
            whichever is the main split used.
        :param repeats: int
            TODO: update this docstring
            The number of times each instance in a dataset is inferred on. Defaults to 1,
            can be increased for techniques like majority voting.
        """
        pass
@@ -428,11 +447,7 @@ class ConfigurableTask(Task):
    CONFIG = None

    def __init__(
        self, data_dir=None, cache_dir=None, download_mode=None, config: dict = None
    ):
        # Get pre-configured attributes
        self._config = self.CONFIG
@@ -442,11 +457,13 @@ class ConfigurableTask(Task):
            self._config = TaskConfig(**config)
        # Overwrite configs
        else:
            if config is not None:
                self._config.__dict__.update(config)

        if self._config is None:
            raise ValueError(
                "Must pass a config to ConfigurableTask, either in cls.CONFIG or `config` kwarg"
            )

        if self._config.output_type is not None:
            self.OUTPUT_TYPE = self._config.output_type
@@ -464,16 +481,22 @@ class ConfigurableTask(Task):
            self._higher_is_better = {}
            for metric_config in self._config.metric_list:
                metric_name = metric_config["metric"]
                aggregation = metric_config["aggregation"]
                higher_is_better = metric_config["higher_is_better"]
                kwargs = {
                    key: metric_config[key]
                    for key in metric_config
                    if key not in ["metric", "aggregation", "higher_is_better"]
                }

                self._aggregation_list[metric_name] = AGGREGATION_REGISTRY[aggregation]

                if metric_name in METRIC_REGISTRY.keys():
                    self._metric_list[metric_name] = METRIC_REGISTRY[metric_name]
                    self._higher_is_better[metric_name] = HIGHER_IS_BETTER_REGISTRY[
                        metric_name
                    ]
                else:
                    self._higher_is_better[metric_name] = higher_is_better
                    try:
@@ -481,7 +504,7 @@ class ConfigurableTask(Task):
                        self._metric_list[metric_name] = metric_object
                        self._metric_kwargs[metric_name] = kwargs
                    except Exception:
                        raise Warning(
                            "{} not found in the evaluate library!".format(metric_name),
                            "Please check https://huggingface.co/evaluate-metric",
@@ -492,7 +515,7 @@ class ConfigurableTask(Task):
        self._fewshot_docs = None

        self._filters = []
        if self._config.filter_list is not None:
            for filter_config in self._config.filter_list:
                for filter_pipeline in filter_config:
                    filter_name = filter_config["name"]
@@ -501,39 +524,28 @@ class ConfigurableTask(Task):
                    for function in filter_functions:
                        kwargs = {
                            key: function[key] for key in function if key != "function"
                        }
                        components.append([function["function"], kwargs])

                    filter_pipeline = build_filter_ensemble(filter_name, components)
                    self._filters.append(filter_pipeline)
        else:
            self._filters = [
                build_filter_ensemble("take_first", [["take_first", None]])
            ]

        if self._config.use_prompt is not None:
            eval_logger.info(f"loading prompt {self._config.use_prompt}")
            self.prompt = get_prompt(
                self._config.use_prompt, self.DATASET_PATH, self.DATASET_NAME
            )
        else:
            self.prompt = None

        if self.fewshot_docs() is not None:
            self.sampler = samplers.Sampler(
                list(self.fewshot_docs()), self, rnd=random.Random()
            )  # TODO: pass the correct docs in here
    def has_training_docs(self):
        if self._config.training_split is not None:
@@ -566,14 +578,14 @@ class ConfigurableTask(Task):
            return self.dataset[self._config.test_split]

    def fewshot_docs(self):
        if (self._config.num_fewshot > 0) and (self._config.fewshot_split is None):
            eval_logger.warning(
                "num_fewshot > 0 but fewshot_split is None. "
                "using preconfigured rule."
            )
            return super().fewshot_docs()
        elif self._config.fewshot_split is not None:
            return self.dataset[self._config.fewshot_split]

    def should_decontaminate(self):
@@ -600,7 +612,7 @@ class ConfigurableTask(Task):
            doc_to_text = self.prompt
        else:
            doc_to_text = self._config.doc_to_text

        if type(doc_to_text) == str:
            return utils.apply_template(doc_to_text, doc)
        elif callable(doc_to_text):
@@ -630,55 +642,55 @@ class ConfigurableTask(Task):
    def construct_requests(self, doc, ctx, **kwargs):
        if self.OUTPUT_TYPE == "loglikelihood":
            arguments = (ctx, self.doc_to_target(doc))
        elif self.OUTPUT_TYPE == "loglikelihood_rolling":
            arguments = (self.doc_to_target(doc),)
        elif self.OUTPUT_TYPE == "multiple_choice":
            # we pass the user-defined answer_choices var (in aliases) and translate the result to a Python list.
            # TODO: any cleaner way to do this?
            choices = ast.literal_eval(
                utils.apply_template(
                    self._config.template_aliases + "{{answer_choices}}", doc
                )
            )
            request_list = [
                Instance(
                    request_type="loglikelihood",
                    doc=doc,
                    arguments=(ctx, " {}".format(choice)),
                    idx=i,
                    **kwargs,
                )
                for i, choice in enumerate(choices)
            ]
            # TODO: we should raise a warning telling users this will at most ~2x runtime.
            if "acc_mutual_info" in self._metric_list.keys():
                # if we are calculating multiple choice accuracy
                # using mutual information instead of raw loglikelihood as metric, need unconditional lls.
                # here mutual info refers to calculating
                # log(P(choice|ctx) / P(choice)) = log(P(choice|ctx)) - log(P(choice))
                # in other words normalizing by subtracting the unconditional logprob of each choice.
                request_list.extend(
                    [
                        Instance(
                            request_type="loglikelihood",
                            doc=doc,
                            arguments=("", "{}".format(choice)),
                            idx=i,
                            **kwargs,
                        )
                        for i, choice in enumerate(choices)
                    ]
                )
            return request_list
        elif self.OUTPUT_TYPE == "greedy_until":
            arguments = (ctx, self._config.delimiter)

        return Instance(
            request_type=self.OUTPUT_TYPE, doc=doc, arguments=arguments, idx=0, **kwargs
        )
    def process_results(self, doc, results):
@@ -697,11 +709,20 @@ class ConfigurableTask(Task):
                "bits_per_byte": (loglikelihood, bytes_),
            }
        elif self.OUTPUT_TYPE == "multiple_choice":
            lls = [
                res[0] for res in results
            ]  # only retain loglikelihoods, discard is_greedy

            gold = int(self.doc_to_target(doc))
            # retrieve choices in List[str] form, to compute choice lengths, etc.
            choices = ast.literal_eval(
                utils.apply_template(
                    self._config.template_aliases + "{{answer_choices}}", doc
                )
            )
            if (
                2 * len(choices) == len(lls)
                and "acc_mutual_info" in self._metric_list.keys()
            ):
                # then we are doing mutual info.
                # this stores the "dryrun" / unconditional answer loglikelihoods
                lls_unconditional = lls[1::2]
@@ -722,12 +743,16 @@ class ConfigurableTask(Task):
            if "exact_match" in self._metric_list.keys():
                # TODO: this gets score of 0 on arc_challenge for pythia-70m. need to test that this works properly
                is_greedy = [
                    res[1] for res in results
                ]  # take only the `is_greedy` results
                is_greedy = is_greedy[gold]  # take value for the gold answer
                result_dict["exact_match"] = int(is_greedy)

            if "acc_mutual_info" in self._metric_list.keys():
                lls_mutual_info = [
                    ll_c - ll_u for ll_c, ll_u in zip(lls, lls_unconditional)
                ]
                acc_mutual_info = 1.0 if np.argmax(lls_mutual_info) == gold else 0.0
                result_dict["acc_mutual_info"] = acc_mutual_info
@@ -740,15 +765,14 @@ class ConfigurableTask(Task):
            for key, result in zip(self._metric_list.keys(), results):
                _dict = self._metric_list[key].compute(
                    references=[gold], predictions=[result], **self._metric_kwargs[key]
                )
                result_dict[key] = _dict[key]
        else:
            raise ValueError(
                f"Passed invalid output_type '{self.OUTPUT_TYPE}' ! Please use one of ",
                "'loglikelihood', 'loglikelihood_rolling', 'greedy_until'",
            )

        return result_dict
@@ -769,17 +793,21 @@ class MultipleChoiceTask(Task):
    def construct_requests(self, doc, ctx, **kwargs):
        # TODO: add mutual info here?
        return [
            Instance(
                request_type="loglikelihood",
                doc=doc,
                arguments=(ctx, " {}".format(choice)),
                idx=i,
                **kwargs,
            )
            for i, choice in enumerate(doc["choices"])
        ]

    def process_results(self, doc, results):
        results = [
            res[0] for res in results
        ]  # only retain loglikelihoods, discard is_greedy TODO: do we need is_greedy anywhere?
        gold = doc["gold"]

        acc = 1.0 if np.argmax(results) == gold else 0.0
@@ -815,9 +843,7 @@ class PerplexityTask(Task):
        assert k == 0
        return []

    def fewshot_context(self, doc, num_fewshot, rnd=None):
        assert (
            num_fewshot == 0
        ), "The number of fewshot examples must be 0 for perplexity tasks."
@@ -846,7 +872,13 @@ class PerplexityTask(Task):
    def construct_requests(self, doc, ctx, **kwargs):
        assert not ctx
        return Instance(
            request_type=self.OUTPUT_TYPE,
            doc=doc,
            arguments=(self.doc_to_target(doc),),
            idx=0,
            **kwargs,
        )

    def process_results(self, doc, results):
        (loglikelihood,) = results
...
@@ -10,7 +10,12 @@ import lm_eval.api.metrics
import lm_eval.tasks
import lm_eval.models
from lm_eval.utils import (
    positional_deprecated,
    run_task_tests,
    make_table,
    get_git_commit_hash,
)
from lm_eval.logger import eval_logger
@@ -127,20 +132,20 @@ def evaluate(
        Dictionary of results
    """
    # decontaminate = decontamination_ngrams_path is not None

    results = collections.defaultdict(dict)
    versions = collections.defaultdict(dict)
    requests = collections.defaultdict(list)
    # requests_origin = collections.defaultdict(list)
    # docs = {}

    # get lists of each type of request
    for task_name, task in task_dict.items():
        versions[task_name] = task.VERSION
        # deterministically shuffle docs and chop off the first `limit` because sometimes docs are in some kind of order
        # task_docs = list(task_doc_func())
        # rnd = random.Random()
@@ -150,9 +155,13 @@ def evaluate(
        # for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):
        task.build_all_requests(limit=limit)
        # aggregate Instances by LM method requested to get output.
        reqtype = (
            "loglikelihood"
            if task.OUTPUT_TYPE == "multiple_choice"
            else task.OUTPUT_TYPE
        )  # TODO: this is hacky, fix in task.py
        requests[reqtype].extend(task.instances)

    ### Run LM on inputs, get all outputs ###
    # execute each type of request
    for reqtype, reqs in requests.items():
@@ -161,7 +170,7 @@ def evaluate(
        cloned_reqs = []
        for req in reqs:
            cloned_reqs.extend([req] * req.repeats)

        # run requests through model
        resps = getattr(lm, reqtype)(cloned_reqs)
@@ -175,7 +184,7 @@ def evaluate(
        task.apply_filters()

    ### Collect values of metrics on all datapoints ###
    # TODO: make metric configurable, add metric registry
    vals = collections.defaultdict(list)
    # unpack results and sort back in order and return control to Task
@@ -183,11 +192,17 @@ def evaluate(
        # calculate values for each filter setup (TODO: make getting list of keys cleaner)
        # TODO: make it possible to use a different metric per key
        for key in task.instances[0].filtered_resps.keys():
            for doc_id, doc in enumerate(
                itertools.islice(task.test_docs(), 0, limit)
                if task.has_test_docs()
                else task.validation_docs()
            ):
                # subset instances to only this document id ; sort by idx
                requests = list(filter(lambda x: x.doc_id == doc_id, task.instances))
                requests.sort(key=lambda x: x.idx)
                metrics = task.process_results(
                    doc, [req.filtered_resps[key] for req in requests]
                )
                for metric, value in metrics.items():
                    vals[(task_name, key, metric)].append(value)
@@ -195,7 +210,9 @@ def evaluate(
    # aggregate results ; run bootstrap CIs
    for (task_name, key, metric), items in vals.items():
        task = task_dict[task_name]
        results[task_name][metric + " - filter=" + key] = task.aggregation()[metric](
            items
        )
    # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
    # so we run them less iterations. still looking for a cleaner way to do this
...
@@ -6,10 +6,10 @@ from . import extraction
FILTER_REGISTRY = {
    "take_first": selection.TakeFirstFilter,
    "regex": extraction.RegexFilter,
    # TODO: implement this filter. either it should take in an arbitrary "scoring"/reward function
    # that takes an input and returns a scalar and then should select the max reward,
    # or should implement different filters for different ways of handling a reward model's inference.
    # "arg_max": selection.ArgMaxFilter,
}
@@ -24,11 +24,11 @@ def build_filter_ensemble(filter_name, components):
    filters = []
    for (function, kwargs) in components:
        if kwargs is None:
            f = get_filter(function)()
        else:
            # create a filter given its name in the registry
            f = get_filter(function)(**kwargs)  # TODO: pass kwargs to filters properly
        # add the filter as a pipeline step
        filters.append(f)
...
@@ -4,7 +4,7 @@ from lm_eval.api.filter import Filter
class DecontaminationFilter(Filter):
    """
    A filter which evaluates
    """

    name = "track_decontamination"
@@ -12,7 +12,7 @@ class DecontaminationFilter(Filter):
    def __init__(self, path):
        """
        TODO: make sure only ever run one time on the train set (should this be cached as a class var? keyed by value for "path").
        should further cache result on a given (task_name, doc_id)
        """
        self._decontam_results = None
@@ -21,4 +21,4 @@ class DecontaminationFilter(Filter):
        """
        Return {"no_contamination", "only_contamination"} keys for the 2 different subsets
        """
        pass
@@ -4,10 +4,7 @@ from lm_eval.api.filter import Filter
class RegexFilter(Filter):
    """ """

    def __init__(self, regex_pattern=r"#### (\-?[0-9\.\,]+)", fallback="[invalid]"):
        """
@@ -20,7 +17,7 @@ class RegexFilter(Filter):
    def apply(self, resps):
        # here, we assume we have a list, in which each element is
        # a list of model responses for some particular input/target pair.
        # so we process each of these (same input/target response sets)
        # independently (and keep them a list.)
        def filter_set(inst):
...
from lm_eval.api.filter import Filter


class TakeFirstFilter:
    def __init__(self):
        """
        Can define custom behavior here, if an individual instantiation of a Filter class should have state.
@@ -11,4 +11,4 @@ class TakeFirstFilter:
        """
        Assuming each entry of `resps` is a list of model responses, we discard all but the first response.
        """
        return map(lambda r: r[0], resps)
import logging

logging.basicConfig(
    format="%(asctime)s,%(msecs)03d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
    datefmt="%Y-%m-%d:%H:%M:%S",
    level=logging.INFO,
)

eval_logger = logging.getLogger("lm-eval")
@@ -111,7 +111,11 @@ class HFLM(LM):
    def _model_generate(self, context, max_length, eos_token_id):
        return self.gpt2.generate(
            context,
            max_length=max_length,
            pad_token_id=eos_token_id,
            eos_token_id=eos_token_id,
            do_sample=False,
        )

    def loglikelihood(self, requests):
...
@@ -2,46 +2,44 @@ from lm_eval.logger import eval_logger
from promptsource.templates import DatasetTemplates

# TODO: decide whether we want jinja2 or f-string prompts. would it be cursed to support both?

# Prompt library.
# Stores prompts in a dictionary indexed by 2 levels:
# prompt category name, and prompt name.
# This allows us to access prompts
PROMPT_REGISTRY = {
    "qa-basic": {
        "question-newline-answer": "Question: {{question}}\nAnswer:",
        "q-newline-a": "Q: {{question}}\nA:",
    },
}


def get_prompt(prompt_id: str, dataset_name=None, subset_name=None):
    # unpack prompt name
    category_name, prompt_name = prompt_id.split(":")
    eval_logger.info(f"Loading prompt from {category_name}")
    if category_name == "promptsource":
        try:
            # prompts = DatasetTemplates(dataset_name, dataset_path)
            if subset_name is None:
                prompts = DatasetTemplates(dataset_name=dataset_name)
            else:
                prompts = DatasetTemplates(
                    dataset_name=dataset_name, subset_name=subset_name
                )
        except Exception:
            raise ValueError(f"{dataset_name} and {subset_name} not found")
        if prompt_name in prompts.all_template_names:
            return prompts[prompt_name]
        else:
            raise ValueError(
                f"{prompt_name} not in prompt list {prompts.all_template_names}"
            )
    else:
        try:
            return PROMPT_REGISTRY[category_name][prompt_name]
        except Exception:
            raise ValueError(
                f"expected only a single `:` as separator between \
prompt category and name, but got `{prompt_id}` instead"
            )
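A quick illustration of the two lookup paths above (the built-in registry values come from this file; the promptsource template name is an assumption and must exist in that library):
```python
# Built-in registry lookup: "<category>:<prompt name>"
template = get_prompt("qa-basic:question-newline-answer")
# -> "Question: {{question}}\nAnswer:"

# promptsource lookup (assumes promptsource ships a "can we infer" template for super_glue/cb)
cb_prompt = get_prompt(
    "promptsource:can we infer", dataset_name="super_glue", subset_name="cb"
)
```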
@@ -10,8 +10,8 @@ from lm_eval.api.register import (
    register_task,
    register_group,
    task_registry,
    group_registry,
)


def get_task_name_from_config(task_config):
@@ -28,20 +28,19 @@ for root, subdirs, file_list in os.walk(task_dir):
            config = utils.load_yaml_config(yaml_path)

            SubClass = type(
                config["task"] + "ConfigurableTask",
                (ConfigurableTask,),
                {"CONFIG": TaskConfig(**config)},
            )

            if "task" in config:
                task_name = "{}:{}".format(
                    get_task_name_from_config(config), config["task"]
                )
                register_task(task_name)(SubClass)

            if "group" in config:
                for group in config["group"]:
                    register_group(group)(SubClass)
        except Exception as err:
            print(f"Unexpected {err=}, {type(err)=}")
@@ -50,6 +49,7 @@ TASK_REGISTRY = task_registry
GROUP_REGISTRY = group_registry
ALL_TASKS = sorted(list(TASK_REGISTRY.keys()) + list(GROUP_REGISTRY.keys()))


def get_task(task_name, config):
    try:
        return TASK_REGISTRY[task_name](config=config)
@@ -90,19 +90,15 @@ def get_task_dict(task_name_list: List[Union[str, dict, Task]], **kwargs):
                    if task_name not in task_name_from_registry_dict:
                        task_name_from_registry_dict = {
                            **task_name_from_registry_dict,
                            task_name: get_task(task_name=task_name, config=config),
                        }
            else:
                task_name = task_element
                if task_name not in task_name_from_registry_dict:
                    task_name_from_registry_dict = {
                        **task_name_from_registry_dict,
                        task_name: get_task(task_name=task_element, config=config),
                    }

        elif isinstance(task_element, dict):
            task_element.update(config)
@@ -110,22 +106,22 @@ def get_task_dict(task_name_list: List[Union[str, dict, Task]], **kwargs):
                **task_name_from_config_dict,
                get_task_name_from_config(task_element): ConfigurableTask(
                    config=task_element
                ),
            }

        elif isinstance(task_element, Task):
            task_name_from_object_dict = {
                **task_name_from_object_dict,
                get_task_name_from_object(task_element): task_element,
            }

    # task_name_from_registry_dict = {
    #     task_name: get_task(
    #         task_name=task_name,
    #         task_config=config
    #     )
    #     for group_name in task_name_list for task_name in GROUP_REGISTRY[group_name]
    #     if (isinstance(group_name, str)) and (group_name in GROUP_REGISTRY)
    # }

    # task_name_from_config_dict = {
@@ -142,11 +138,11 @@ def get_task_dict(task_name_list: List[Union[str, dict, Task]], **kwargs):
    #     if isinstance(task_object, Task)
    # }

    assert set(task_name_from_registry_dict.keys()).isdisjoint(
        set(task_name_from_object_dict.keys())
    )
    return {
        **task_name_from_registry_dict,
        **task_name_from_config_dict,
        **task_name_from_object_dict,
    }
@@ -12,6 +12,7 @@ a co-occurrence method fail to answer correctly) and an Easy Set of 5,197 questi
Homepage: https://allenai.org/data/arc
"""
from lm_eval import utils
from lm_eval.prompts import get_prompt
from lm_eval.api.task import MultipleChoiceTask
@@ -27,6 +28,7 @@ _CITATION = """
}
"""


@register_group("arc")
@register_task("arc_easy")
class ARCEasy(MultipleChoiceTask):
...