Commit cea713dc authored by lintangsutawika's avatar lintangsutawika
Browse files

can process loglikelihood requests

parent 84191b83
......@@ -11,6 +11,7 @@ import numpy as np
from typing import List, Union
from lm_eval.api import METRIC_REGISTRY, AGGREGATION_REGISTRY, HIGHER_IS_BETTER_REGISTRY
from lm_eval.api.instance import Instance
from lm_eval.api.metrics import get_metric, get_aggregation, mean, weighted_perplexity, bits_per_byte
from lm_eval import utils
......@@ -45,6 +46,8 @@ class TaskConfig(dict):
filters: str = None #TODO: need to make this typehint `list`?
normalization: str = None # TODO: add length-normalization of various types, mutual info
stop_sequences: list = None # TODO: allow passing of stop sequences to greedy gen.
should_decontaminate: bool = False
doc_to_decontamination_query: str = None
def __post_init__(self):
# allow user-specified aliases so that users can
......@@ -379,6 +382,10 @@ class ConfigurableTask(Task):
):
self._config = TaskConfig(**config)
if self._config.output_type is not None:
self.OUTPUT_TYPE = self._config.output_type
if self._config.dataset_path is not None:
self.DATASET_PATH = self._config.dataset_path
......@@ -392,17 +399,18 @@ class ConfigurableTask(Task):
self._metric_kwargs = {}
for metric_config in self._config.metric_list:
metric_name = metric_config['name']
metric_name = metric_config['metric']
aggregation = metric_config['aggregation']
higher_is_better = metric_config['higher_is_better']
kwargs = {key: metric_config[key] for key in metric_config if key not in ['name', 'aggregation', 'higher_is_better']}
self._aggregation_list[metric_name] = AGGREGATION_REGISTRY[aggregation]
self._higher_is_better[metric_name] = higher_is_better
if metric_name in METRIC_REGISTRY.keys():
self._metric_list[metric_name] = METRIC_REGISTRY[metric_name]
self._higher_is_better[metric_name] = HIGHER_IS_BETTER_REGISTRY[metric_name]
else:
self._higher_is_better[metric_name] = higher_is_better
try:
metric_object = evaluate.load(metric_name)
self._metric_list[metric_name] = metric_object
......@@ -454,6 +462,13 @@ class ConfigurableTask(Task):
if self._config.test_split is not None:
return self.dataset[self._config.test_split]
def should_decontaminate(self):
    """Return the ``should_decontaminate`` flag stored on this task's config."""
    config = self._config
    return config.should_decontaminate
def doc_to_decontamination_query(self, doc):
    """Render the configured decontamination query template for ``doc``.

    Returns ``None`` when the config's ``should_decontaminate`` flag is
    falsy; otherwise returns whatever ``utils.apply_template`` produces for
    the config's ``doc_to_decontamination_query`` template and ``doc``.
    """
    config = self._config
    if not config.should_decontaminate:
        return None
    template = config.doc_to_decontamination_query
    return utils.apply_template(template, doc)
def _process_doc(self, doc):
"""
Override this to process (detokenize, strip, replace, etc.) individual
......@@ -473,15 +488,15 @@ class ConfigurableTask(Task):
def construct_requests(self, doc, ctx, **kwargs):
if self.output_type == "loglikelihood":
if self.OUTPUT_TYPE == "loglikelihood":
arguments=(ctx, self.doc_to_target(doc))
elif self.output_type == "loglikelihood_rolling":
elif self.OUTPUT_TYPE == "loglikelihood_rolling":
arguments=(self.doc_to_target(doc),)
elif self.output_type == "greedy_until":
elif self.OUTPUT_TYPE == "greedy_until":
arguments=(ctx, "\n\n")
return Instance(
request_type=self.output_type,
request_type=self.OUTPUT_TYPE,
doc=doc,
arguments=arguments,
**kwargs
......@@ -489,28 +504,35 @@ class ConfigurableTask(Task):
# Turn the model outputs for one document into a dict of per-document metric
# values, which is returned to the caller.
#
# NOTE(review): this span appears to come from a rendered diff whose +/-
# markers and indentation were stripped, so pre-change and post-change lines
# seem to be interleaved (the `gold` lookup, the metric loop and the final
# `result_dict[key]` assignment each appear twice). Confirm the real statement
# order and nesting against the repository before relying on it.
def process_results(self, doc, results):
# Reference answer: the aliased doc field when gold_alias is configured,
# otherwise self.doc_to_target(doc).
if self._config.gold_alias is not None:
gold = doc[self._config.gold_alias]
else:
gold = self.doc_to_target(doc)
result_dict = {}
# Pair each configured metric with the positionally corresponding result.
# This variant calls the metric object directly (no `.compute`, no kwargs).
for key, result in zip(self._metric_list.keys(), results):
_dict = self._metric_list[key](
references=[gold],
predictions=[result],
)
if self.OUTPUT_TYPE == "loglikelihood":
# The first entry of `results` is unpacked as (ll, is_greedy).
results = results[0]
ll, is_greedy = results
# NOTE(review): the raw `ll` value is stored under "perplexity" —
# presumably converted later by the aggregation; confirm.
result_dict = {"perplexity": ll, "accuracy": int(is_greedy)}
elif self.OUTPUT_TYPE == "loglikelihood_rolling":
# This branch is a no-op: nothing is added to result_dict.
pass
elif self.OUTPUT_TYPE == "greedy_until":
# Same reference-answer lookup as above.
if self._config.gold_alias is not None:
gold = doc[self._config.gold_alias]
else:
gold = self.doc_to_target(doc)
# This variant calls `.compute(...)` on the metric object and forwards
# the per-metric kwargs stored in self._metric_kwargs.
for key, result in zip(self._metric_list.keys(), results):
_dict = self._metric_list[key].compute(
references=[gold],
predictions=[result],
**self._metric_kwargs[key]
)
# Keep only the entry of the metric's output that is keyed by the metric name.
result_dict[key] = _dict[key]
result_dict[key] = _dict[key]
return result_dict
def aggregation(self):
    """Return the ``_aggregation_list`` attribute held by this task."""
    aggregations = self._aggregation_list
    return aggregations
def higher_is_better(self):
    """Return the ``_higher_is_better`` attribute held by this task."""
    directions = self._higher_is_better
    return directions
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment