Commit 66c58194 authored by lintangsutawika's avatar lintangsutawika
Browse files

can process doc_to_text and doc_to_target as function

parent 51b795cf
......@@ -12,8 +12,10 @@ import functools
import datasets
import numpy as np
from lm_eval import utils
from typing import Union
from collections.abc import Callable
from lm_eval import utils
from lm_eval.api import samplers
from lm_eval.api.instance import Instance
from lm_eval.api.metrics import (
......@@ -42,8 +44,8 @@ class TaskConfig(dict):
fewshot_split: str = None # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaling (?)
template_aliases: str = ""
doc_to_text: str = ""
doc_to_target: str = ""
doc_to_text: Union[Callable, str] = None
doc_to_target: Union[Callable, str] = None
num_fewshot: int = 0
......@@ -66,8 +68,11 @@ class TaskConfig(dict):
# allow user-specified aliases so that users can
# force prompt-compatibility for some prompt regardless of
# field names in prompt
self.doc_to_text = self.template_aliases + self.doc_to_text
self.doc_to_target = self.template_aliases + self.doc_to_target
if type(self.doc_to_text) == str:
self.doc_to_text = self.template_aliases + self.doc_to_text
if type(self.doc_to_target) == str:
self.doc_to_target = self.template_aliases + self.doc_to_target
# set "task_name" metadata field based on the "primary" name set
if self.names:
......@@ -439,9 +444,11 @@ class ConfigurableTask(Task):
self.OUTPUT_TYPE = self._config.output_type
if self._config.dataset_path is not None:
print(self._config.dataset_path)
self.DATASET_PATH = self._config.dataset_path
if self._config.dataset_name is not None:
print(self._config.dataset_name)
self.DATASET_NAME = self._config.dataset_name
if self._config.metric_list is not None:
......@@ -546,10 +553,24 @@ class ConfigurableTask(Task):
doc_to_text = get_prompt(self._config.use_prompt)
else:
doc_to_text = self._config.doc_to_text
return utils.apply_template(doc_to_text, doc)
print(doc_to_text)
if type(doc_to_text) == str:
return utils.apply_template(doc_to_text, doc)
elif type(doc_to_text) == Callable:
return doc_to_text(doc)
else:
raise TypeError
def doc_to_target(self, doc):
return utils.apply_template(self._config.doc_to_target, doc)
doc_to_target = self._config.doc_to_target
if type(doc_to_target) == str:
return utils.apply_template(doc_to_target, doc)
elif type(doc_to_target) == Callable:
return doc_to_target(doc)
else:
raise TypeError
def construct_requests(self, doc, ctx, **kwargs):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment