Unverified Commit ff4b2fac authored by Hailey Schoelkopf's avatar Hailey Schoelkopf Committed by GitHub
Browse files

Merge pull request #598 from EleutherAI/patch-multiple-choice

fixes for multiple_choice
parents c87fc788 b4dd86e3
...@@ -73,7 +73,7 @@ class TaskConfig(dict): ...@@ -73,7 +73,7 @@ class TaskConfig(dict):
repeats: int = 1 repeats: int = 1
metric_list: str = None metric_list: str = None
gold_alias: str = None gold_alias: Union[Callable, str] = None
output_type: str = "greedy_until" output_type: str = "greedy_until"
generation_kwargs: dict = None generation_kwargs: dict = None
delimiter: str = "\n\n" delimiter: str = "\n\n"
...@@ -95,7 +95,7 @@ class TaskConfig(dict): ...@@ -95,7 +95,7 @@ class TaskConfig(dict):
self.doc_to_target = self.template_aliases + self.doc_to_target self.doc_to_target = self.template_aliases + self.doc_to_target
if type(self.gold_alias) == str: if type(self.gold_alias) == str:
self.gold_alias = self.template_aliases + self.doc_to_target self.gold_alias = self.template_aliases + self.gold_alias
if self.generation_kwargs or self.output_type == "greedy_until": if self.generation_kwargs or self.output_type == "greedy_until":
assert ( assert (
...@@ -737,10 +737,11 @@ class ConfigurableTask(Task): ...@@ -737,10 +737,11 @@ class ConfigurableTask(Task):
def gold_alias(self, doc): def gold_alias(self, doc):
# TODO: reevaluate if we need this. implemented to have a # TODO: reevaluate if we need this. implemented to have a
# processed version of answer to put into gsm8k exact_match scoring as ref. # processed version of answer to put into gsm8k exact_match scoring as ref.
if self._config.gold_alias: if self._config.gold_alias is not None:
doc_to_target = self._config.gold_alias doc_to_target = self._config.gold_alias
else: else:
doc_to_target = self._config.doc_to_target # doc_to_target = self._config.doc_to_target
return self.doc_to_target(doc)
if type(doc_to_target) == str: if type(doc_to_target) == str:
return utils.apply_template(doc_to_target, doc) return utils.apply_template(doc_to_target, doc)
...@@ -842,7 +843,11 @@ class ConfigurableTask(Task): ...@@ -842,7 +843,11 @@ class ConfigurableTask(Task):
elif self.OUTPUT_TYPE == "multiple_choice": elif self.OUTPUT_TYPE == "multiple_choice":
lls, is_greedy = zip(*results) lls, is_greedy = zip(*results)
gold = int(self.doc_to_target(doc)) if self._config.gold_alias is not None:
gold = int(self.gold_alias(doc))
else:
gold = int(self.doc_to_target(doc))
pred = np.argmax(lls) pred = np.argmax(lls)
# retrieve choices in List[str] form, to compute choice lengths, etc. # retrieve choices in List[str] form, to compute choice lengths, etc.
choices = ast.literal_eval( choices = ast.literal_eval(
......
...@@ -9,7 +9,8 @@ validation_split: validation ...@@ -9,7 +9,8 @@ validation_split: validation
test_split: test test_split: test
template_aliases: "{% set answer_choices = [distractor1, distractor2, distractor3, correct_answer] %}{% set gold = 3 %}" # set the list of possible answer choices, and set what this doc's gold label idx is template_aliases: "{% set answer_choices = [distractor1, distractor2, distractor3, correct_answer] %}{% set gold = 3 %}" # set the list of possible answer choices, and set what this doc's gold label idx is
doc_to_text: "{{support.lstrip()}}\nQuestion: {{question}}\nAnswer:" doc_to_text: "{{support.lstrip()}}\nQuestion: {{question}}\nAnswer:"
doc_to_target: "{{gold}}" # this will be cast to an int. doc_to_target: " {{correct_answer}}"
gold_alias: "{{gold}}" # this will be cast to an int.
metric_list: metric_list:
- metric: acc - metric: acc
aggregation: mean aggregation: mean
......
...@@ -15,7 +15,7 @@ setuptools.setup( ...@@ -15,7 +15,7 @@ setuptools.setup(
packages=setuptools.find_packages(), packages=setuptools.find_packages(),
# required to include yaml files in pip installation # required to include yaml files in pip installation
package_data={ package_data={
"lm_eval": ["**/*.yaml"], "lm_eval": ["**/*.yaml"],
"examples": ["**/*.yaml"], "examples": ["**/*.yaml"],
}, },
include_package_data=True, include_package_data=True,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment