Commit ce5ec854 authored by lintangsutawika

doc_to_choice can run when a promptsource template is defined

parent 5288813a
@@ -315,20 +315,6 @@ class Task(abc.ABC):
         """
         return doc
 
-    def doc_to_choice(self, doc):
-        if self._config.doc_to_choice is None:
-            eval_logger.error("doc_to_choice was called but not set in config")
-        elif type(self._config.doc_to_choice) == list:
-            return self._config.doc_to_choice
-        elif type(self._config.doc_to_choice) == dict:
-            return list(self._config.doc_to_choice.values())
-        elif type(self._config.doc_to_choice) == str:
-            return ast.literal_eval(
-                utils.apply_template(self._config.doc_to_choice, doc)
-            )
-        else:
-            return self._config.doc_to_choice(doc)
-
     @property
     def instances(self):
         """After calling `task.build_all_requests()`, tasks
@@ -789,6 +775,28 @@ class ConfigurableTask(Task):
         else:
             raise TypeError
 
+    def doc_to_choice(self, doc):
+        if self.prompt is not None:
+            doc_to_choice = self.prompt
+        elif self._config.doc_to_choice is None:
+            eval_logger.error("doc_to_choice was called but not set in config")
+        else:
+            doc_to_choice = self._config.doc_to_choice
+
+        if type(doc_to_choice) == str:
+            return ast.literal_eval(utils.apply_template(doc_to_choice, doc))
+        elif type(doc_to_choice) == list:
+            return doc_to_choice
+        elif type(doc_to_choice) == dict:
+            return list(doc_to_choice.values())
+        elif callable(doc_to_choice):
+            return doc_to_choice(doc)
+        elif hasattr(doc_to_choice, "get_answer_choices_list"):
+            return doc_to_choice.get_answer_choices_list(doc)
+        else:
+            raise TypeError
+
     def gold_alias(self, doc):
         # returns a version of the gold target answer to a document,
         # which should be passed into metric for scoring as the ground truth.
...
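
The new ConfigurableTask.doc_to_choice first resolves where the choices come from (a promptsource template when one is set, otherwise the YAML config), then dispatches on the type of that value. Below is a minimal standalone sketch of the dispatch, for illustration only: resolve_choices is a hypothetical name, not part of the harness, and the harness's Jinja templating step (utils.apply_template) is omitted, so the string case here takes an already-rendered list literal.

import ast

def resolve_choices(doc_to_choice, doc):
    # String: assumed already rendered by the harness's templating step;
    # parse it as a Python list literal.
    if isinstance(doc_to_choice, str):
        return ast.literal_eval(doc_to_choice)
    # Explicit list of choices from the config.
    elif isinstance(doc_to_choice, list):
        return doc_to_choice
    # Labeled choices: keep only the values.
    elif isinstance(doc_to_choice, dict):
        return list(doc_to_choice.values())
    # User-supplied function of the document.
    elif callable(doc_to_choice):
        return doc_to_choice(doc)
    # promptsource template object (the case this commit adds).
    elif hasattr(doc_to_choice, "get_answer_choices_list"):
        return doc_to_choice.get_answer_choices_list(doc)
    raise TypeError(f"unsupported doc_to_choice type: {type(doc_to_choice)}")

assert resolve_choices(["yes", "no"], doc={}) == ["yes", "no"]
assert resolve_choices({"0": "yes", "1": "no"}, doc={}) == ["yes", "no"]
assert resolve_choices("['yes', 'no']", doc={}) == ["yes", "no"]
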
group:
- glue-promptsource
task: qnli
dataset_path: glue
dataset_name: qnli
output_type: multiple_choice
training_split: train
validation_split: validation
use_prompt: "promptsource:have all you need"
metric_list:
- metric: acc
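
The YAML above exercises the new branch: use_prompt selects the "have all you need" template from promptsource's glue/qnli collection, so self.prompt is set and doc_to_choice falls through to get_answer_choices_list. A rough sketch of what that template yields, assuming promptsource's DatasetTemplates API; the example record and the printed choices are illustrative, not taken from the commit.

from promptsource.templates import DatasetTemplates

qnli_templates = DatasetTemplates("glue", "qnli")
template = qnli_templates["have all you need"]

# A single glue/qnli record (fields as in the Hugging Face dataset).
example = {
    "question": "What percentage of the farmland grows wheat?",
    "sentence": "More than 50% of this province's farmland grows wheat.",
    "label": 0,
}

# The branch added in this commit reduces to this call:
print(template.get_answer_choices_list(example))  # e.g. ['yes', 'no']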