Commit cea713dc authored by lintangsutawika's avatar lintangsutawika
Browse files

can process looglikelihood requests

parent 84191b83
...@@ -11,6 +11,7 @@ import numpy as np ...@@ -11,6 +11,7 @@ import numpy as np
from typing import List, Union from typing import List, Union
from lm_eval.api import METRIC_REGISTRY, AGGREGATION_REGISTRY, HIGHER_IS_BETTER_REGISTRY
from lm_eval.api.instance import Instance from lm_eval.api.instance import Instance
from lm_eval.api.metrics import get_metric, get_aggregation, mean, weighted_perplexity, bits_per_byte from lm_eval.api.metrics import get_metric, get_aggregation, mean, weighted_perplexity, bits_per_byte
from lm_eval import utils from lm_eval import utils
...@@ -45,6 +46,8 @@ class TaskConfig(dict): ...@@ -45,6 +46,8 @@ class TaskConfig(dict):
filters: str = None #TODO: need to make this typehint `list`? filters: str = None #TODO: need to make this typehint `list`?
normalization: str = None # TODO: add length-normalization of various types, mutual info normalization: str = None # TODO: add length-normalization of various types, mutual info
stop_sequences: list = None # TODO: allow passing of stop sequences to greedy gen. stop_sequences: list = None # TODO: allow passing of stop sequences to greedy gen.
should_decontaminate: bool = False
doc_to_decontamination_query: str = None
def __post_init__(self): def __post_init__(self):
# allow user-specified aliases so that users can # allow user-specified aliases so that users can
...@@ -379,6 +382,10 @@ class ConfigurableTask(Task): ...@@ -379,6 +382,10 @@ class ConfigurableTask(Task):
): ):
self._config = TaskConfig(**config) self._config = TaskConfig(**config)
if self._config.output_type is not None:
self.OUTPUT_TYPE = self._config.output_type
if self._config.dataset_path is not None: if self._config.dataset_path is not None:
self.DATASET_PATH = self._config.dataset_path self.DATASET_PATH = self._config.dataset_path
...@@ -392,17 +399,18 @@ class ConfigurableTask(Task): ...@@ -392,17 +399,18 @@ class ConfigurableTask(Task):
self._metric_kwargs = {} self._metric_kwargs = {}
for metric_config in self._config.metric_list: for metric_config in self._config.metric_list:
metric_name = metric_config['name'] metric_name = metric_config['metric']
aggregation = metric_config['aggregation'] aggregation = metric_config['aggregation']
higher_is_better = metric_config['higher_is_better'] higher_is_better = metric_config['higher_is_better']
kwargs = {key: metric_config[key] for key in metric_config if key not in ['name', 'aggregation', 'higher_is_better']} kwargs = {key: metric_config[key] for key in metric_config if key not in ['name', 'aggregation', 'higher_is_better']}
self._aggregation_list[metric_name] = AGGREGATION_REGISTRY[aggregation] self._aggregation_list[metric_name] = AGGREGATION_REGISTRY[aggregation]
self._higher_is_better[metric_name] = higher_is_better
if metric_name in METRIC_REGISTRY.keys(): if metric_name in METRIC_REGISTRY.keys():
self._metric_list[metric_name] = METRIC_REGISTRY[metric_name] self._metric_list[metric_name] = METRIC_REGISTRY[metric_name]
self._higher_is_better[metric_name] = HIGHER_IS_BETTER_REGISTRY[metric_name]
else: else:
self._higher_is_better[metric_name] = higher_is_better
try: try:
metric_object = evaluate.load(metric_name) metric_object = evaluate.load(metric_name)
self._metric_list[metric_name] = metric_object self._metric_list[metric_name] = metric_object
...@@ -454,6 +462,13 @@ class ConfigurableTask(Task): ...@@ -454,6 +462,13 @@ class ConfigurableTask(Task):
if self._config.test_split is not None: if self._config.test_split is not None:
return self.dataset[self._config.test_split] return self.dataset[self._config.test_split]
def should_decontaminate(self):
return self._config.should_decontaminate
def doc_to_decontamination_query(self, doc):
if self._config.should_decontaminate:
return utils.apply_template(self._config.doc_to_decontamination_query, doc)
def _process_doc(self, doc): def _process_doc(self, doc):
""" """
Override this to process (detokenize, strip, replace, etc.) individual Override this to process (detokenize, strip, replace, etc.) individual
...@@ -473,15 +488,15 @@ class ConfigurableTask(Task): ...@@ -473,15 +488,15 @@ class ConfigurableTask(Task):
def construct_requests(self, doc, ctx, **kwargs): def construct_requests(self, doc, ctx, **kwargs):
if self.output_type == "loglikelihood": if self.OUTPUT_TYPE == "loglikelihood":
arguments=(ctx, self.doc_to_target(doc)) arguments=(ctx, self.doc_to_target(doc))
elif self.output_type == "loglikelihood_rolling": elif self.OUTPUT_TYPE == "loglikelihood_rolling":
arguments=(self.doc_to_target(doc),) arguments=(self.doc_to_target(doc),)
elif self.output_type == "greedy_until": elif self.OUTPUT_TYPE == "greedy_until":
arguments=(ctx, "\n\n") arguments=(ctx, "\n\n")
return Instance( return Instance(
request_type=self.output_type, request_type=self.OUTPUT_TYPE,
doc=doc, doc=doc,
arguments=arguments, arguments=arguments,
**kwargs **kwargs
...@@ -489,16 +504,25 @@ class ConfigurableTask(Task): ...@@ -489,16 +504,25 @@ class ConfigurableTask(Task):
def process_results(self, doc, results): def process_results(self, doc, results):
result_dict = {}
if self.OUTPUT_TYPE == "loglikelihood":
results = results[0]
ll, is_greedy = results
result_dict = {"perplexity": ll, "accuracy": int(is_greedy)}
elif self.OUTPUT_TYPE == "loglikelihood_rolling":
pass
elif self.OUTPUT_TYPE == "greedy_until":
if self._config.gold_alias is not None: if self._config.gold_alias is not None:
gold = doc[self._config.gold_alias] gold = doc[self._config.gold_alias]
else: else:
gold = self.doc_to_target(doc) gold = self.doc_to_target(doc)
result_dict = {}
for key, result in zip(self._metric_list.keys(), results): for key, result in zip(self._metric_list.keys(), results):
_dict = self._metric_list[key]( _dict = self._metric_list[key].compute(
references=[gold], references=[gold],
predictions=[result], predictions=[result],
**self._metric_kwargs[key]
) )
result_dict[key] = _dict[key] result_dict[key] = _dict[key]
...@@ -506,11 +530,9 @@ class ConfigurableTask(Task): ...@@ -506,11 +530,9 @@ class ConfigurableTask(Task):
return result_dict return result_dict
def aggregation(self): def aggregation(self):
return self._aggregation_list return self._aggregation_list
def higher_is_better(self): def higher_is_better(self):
return self._higher_is_better return self._higher_is_better
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment