Commit 66c58194 authored by lintangsutawika
Browse files

Allow doc_to_text and doc_to_target to be given as functions (callables) as well as template strings

parent 51b795cf
...@@ -12,8 +12,10 @@ import functools ...@@ -12,8 +12,10 @@ import functools
import datasets import datasets
import numpy as np import numpy as np
from lm_eval import utils from typing import Union
from collections.abc import Callable
from lm_eval import utils
from lm_eval.api import samplers from lm_eval.api import samplers
from lm_eval.api.instance import Instance from lm_eval.api.instance import Instance
from lm_eval.api.metrics import ( from lm_eval.api.metrics import (
...@@ -42,8 +44,8 @@ class TaskConfig(dict): ...@@ -42,8 +44,8 @@ class TaskConfig(dict):
fewshot_split: str = None # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaling (?) fewshot_split: str = None # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaling (?)
template_aliases: str = "" template_aliases: str = ""
doc_to_text: str = "" doc_to_text: Union[Callable, str] = None
doc_to_target: str = "" doc_to_target: Union[Callable, str] = None
num_fewshot: int = 0 num_fewshot: int = 0
...@@ -66,8 +68,11 @@ class TaskConfig(dict): ...@@ -66,8 +68,11 @@ class TaskConfig(dict):
# allow user-specified aliases so that users can # allow user-specified aliases so that users can
# force prompt-compatibility for some prompt regardless of # force prompt-compatibility for some prompt regardless of
# field names in prompt # field names in prompt
self.doc_to_text = self.template_aliases + self.doc_to_text if type(self.doc_to_text) == str:
self.doc_to_target = self.template_aliases + self.doc_to_target self.doc_to_text = self.template_aliases + self.doc_to_text
if type(self.doc_to_target) == str:
self.doc_to_target = self.template_aliases + self.doc_to_target
# set "task_name" metadata field based on the "primary" name set # set "task_name" metadata field based on the "primary" name set
if self.names: if self.names:
...@@ -439,9 +444,11 @@ class ConfigurableTask(Task): ...@@ -439,9 +444,11 @@ class ConfigurableTask(Task):
self.OUTPUT_TYPE = self._config.output_type self.OUTPUT_TYPE = self._config.output_type
if self._config.dataset_path is not None: if self._config.dataset_path is not None:
print(self._config.dataset_path)
self.DATASET_PATH = self._config.dataset_path self.DATASET_PATH = self._config.dataset_path
if self._config.dataset_name is not None: if self._config.dataset_name is not None:
print(self._config.dataset_name)
self.DATASET_NAME = self._config.dataset_name self.DATASET_NAME = self._config.dataset_name
if self._config.metric_list is not None: if self._config.metric_list is not None:
...@@ -546,10 +553,24 @@ class ConfigurableTask(Task): ...@@ -546,10 +553,24 @@ class ConfigurableTask(Task):
doc_to_text = get_prompt(self._config.use_prompt) doc_to_text = get_prompt(self._config.use_prompt)
else: else:
doc_to_text = self._config.doc_to_text doc_to_text = self._config.doc_to_text
return utils.apply_template(doc_to_text, doc)
print(doc_to_text)
if type(doc_to_text) == str:
return utils.apply_template(doc_to_text, doc)
elif type(doc_to_text) == Callable:
return doc_to_text(doc)
else:
raise TypeError
def doc_to_target(self, doc): def doc_to_target(self, doc):
return utils.apply_template(self._config.doc_to_target, doc) doc_to_target = self._config.doc_to_target
if type(doc_to_target) == str:
return utils.apply_template(doc_to_target, doc)
elif type(doc_to_target) == Callable:
return doc_to_target(doc)
else:
raise TypeError
def construct_requests(self, doc, ctx, **kwargs): def construct_requests(self, doc, ctx, **kwargs):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment