Commit 4e72f165 authored by lintangsutawika
Browse files

fixes for multiple_choice

parent bafaef26
...@@ -73,7 +73,7 @@ class TaskConfig(dict): ...@@ -73,7 +73,7 @@ class TaskConfig(dict):
repeats: int = 1 repeats: int = 1
metric_list: str = None metric_list: str = None
gold_alias: str = None gold_alias: Union[Callable, str] = None
output_type: str = "greedy_until" output_type: str = "greedy_until"
generation_kwargs: dict = None generation_kwargs: dict = None
delimiter: str = "\n\n" delimiter: str = "\n\n"
...@@ -95,7 +95,7 @@ class TaskConfig(dict): ...@@ -95,7 +95,7 @@ class TaskConfig(dict):
self.doc_to_target = self.template_aliases + self.doc_to_target self.doc_to_target = self.template_aliases + self.doc_to_target
if type(self.gold_alias) == str: if type(self.gold_alias) == str:
self.gold_alias = self.template_aliases + self.doc_to_target self.gold_alias = self.template_aliases + self.gold_alias
if self.generation_kwargs or self.output_type == "greedy_until": if self.generation_kwargs or self.output_type == "greedy_until":
assert ( assert (
...@@ -737,10 +737,11 @@ class ConfigurableTask(Task): ...@@ -737,10 +737,11 @@ class ConfigurableTask(Task):
def gold_alias(self, doc): def gold_alias(self, doc):
# TODO: reevaluate if we need this. implemented to have a # TODO: reevaluate if we need this. implemented to have a
# processed version of answer to put into gsm8k exact_match scoring as ref. # processed version of answer to put into gsm8k exact_match scoring as ref.
if self._config.gold_alias: if self._config.gold_alias is not None:
doc_to_target = self._config.gold_alias doc_to_target = self._config.gold_alias
else: else:
doc_to_target = self._config.doc_to_target # doc_to_target = self._config.doc_to_target
return self.doc_to_target(doc)
if type(doc_to_target) == str: if type(doc_to_target) == str:
return utils.apply_template(doc_to_target, doc) return utils.apply_template(doc_to_target, doc)
...@@ -842,7 +843,11 @@ class ConfigurableTask(Task): ...@@ -842,7 +843,11 @@ class ConfigurableTask(Task):
elif self.OUTPUT_TYPE == "multiple_choice": elif self.OUTPUT_TYPE == "multiple_choice":
lls, is_greedy = zip(*results) lls, is_greedy = zip(*results)
gold = int(self.doc_to_target(doc)) if self._config.gold_alias is not None:
gold = int(self.gold_alias(doc))
else:
gold = int(self.doc_to_target(doc))
pred = np.argmax(lls) pred = np.argmax(lls)
# retrieve choices in List[str] form, to compute choice lengths, etc. # retrieve choices in List[str] form, to compute choice lengths, etc.
choices = ast.literal_eval( choices = ast.literal_eval(
......
...@@ -9,7 +9,8 @@ validation_split: validation ...@@ -9,7 +9,8 @@ validation_split: validation
test_split: test test_split: test
template_aliases: "{% set answer_choices = [distractor1, distractor2, distractor3, correct_answer] %}{% set gold = 3 %}" # set the list of possible answer choices, and set what this doc's gold label idx is template_aliases: "{% set answer_choices = [distractor1, distractor2, distractor3, correct_answer] %}{% set gold = 3 %}" # set the list of possible answer choices, and set what this doc's gold label idx is
doc_to_text: "{{support.lstrip()}}\nQuestion: {{question}}\nAnswer:" doc_to_text: "{{support.lstrip()}}\nQuestion: {{question}}\nAnswer:"
doc_to_target: "{{gold}}" # this will be cast to an int. doc_to_target: " {{correct_answer}}"
gold_alias: "{{gold}}" # this will be cast to an int.
metric_list: metric_list:
- metric: acc - metric: acc
aggregation: mean aggregation: mean
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment