Unverified Commit ff4b2fac authored by Hailey Schoelkopf's avatar Hailey Schoelkopf Committed by GitHub
Browse files

Merge pull request #598 from EleutherAI/patch-multiple-choice

fixes for multiple_choice
parents c87fc788 b4dd86e3
......@@ -73,7 +73,7 @@ class TaskConfig(dict):
repeats: int = 1
metric_list: str = None
gold_alias: str = None
gold_alias: Union[Callable, str] = None
output_type: str = "greedy_until"
generation_kwargs: dict = None
delimiter: str = "\n\n"
......@@ -95,7 +95,7 @@ class TaskConfig(dict):
self.doc_to_target = self.template_aliases + self.doc_to_target
if type(self.gold_alias) == str:
self.gold_alias = self.template_aliases + self.doc_to_target
self.gold_alias = self.template_aliases + self.gold_alias
if self.generation_kwargs or self.output_type == "greedy_until":
assert (
......@@ -737,10 +737,11 @@ class ConfigurableTask(Task):
def gold_alias(self, doc):
# TODO: reevaluate if we need this. implemented to have a
# processed version of answer to put into gsm8k exact_match scoring as ref.
if self._config.gold_alias:
if self._config.gold_alias is not None:
doc_to_target = self._config.gold_alias
else:
doc_to_target = self._config.doc_to_target
# doc_to_target = self._config.doc_to_target
return self.doc_to_target(doc)
if type(doc_to_target) == str:
return utils.apply_template(doc_to_target, doc)
......@@ -842,7 +843,11 @@ class ConfigurableTask(Task):
elif self.OUTPUT_TYPE == "multiple_choice":
lls, is_greedy = zip(*results)
gold = int(self.doc_to_target(doc))
if self._config.gold_alias is not None:
gold = int(self.gold_alias(doc))
else:
gold = int(self.doc_to_target(doc))
pred = np.argmax(lls)
# retrieve choices in List[str] form, to compute choice lengths, etc.
choices = ast.literal_eval(
......
......@@ -9,7 +9,8 @@ validation_split: validation
test_split: test
template_aliases: "{% set answer_choices = [distractor1, distractor2, distractor3, correct_answer] %}{% set gold = 3 %}" # set the list of possible answer choices, and set what this doc's gold label idx is
doc_to_text: "{{support.lstrip()}}\nQuestion: {{question}}\nAnswer:"
doc_to_target: "{{gold}}" # this will be cast to an int.
doc_to_target: " {{correct_answer}}"
gold_alias: "{{gold}}" # this will be cast to an int.
metric_list:
- metric: acc
aggregation: mean
......
......@@ -15,7 +15,7 @@ setuptools.setup(
packages=setuptools.find_packages(),
# required to include yaml files in pip installation
package_data={
"lm_eval": ["**/*.yaml"],
"lm_eval": ["**/*.yaml"],
"examples": ["**/*.yaml"],
},
include_package_data=True,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment