Commit 18fde066 authored by Baber's avatar Baber
Browse files

for multiple_inputs use `doc_to_text: list[str]``

parent 85970a3e
...@@ -18,6 +18,7 @@ from typing import ( ...@@ -18,6 +18,7 @@ from typing import (
Optional, Optional,
Tuple, Tuple,
Union, Union,
cast,
) )
import datasets import datasets
...@@ -940,16 +941,16 @@ class ConfigurableTask(Task): ...@@ -940,16 +941,16 @@ class ConfigurableTask(Task):
test_choice = self.doc_to_choice(test_doc) test_choice = self.doc_to_choice(test_doc)
if self.multiple_inputs: if self.multiple_inputs:
# we require: # we require:
# doc_to_text: int # doc_to_text: list
# doc_to_choice: list # doc_to_choice: list
# doc_to_target: str # doc_to_target: int
# e.g. text: 1, choice: [Maria was better than Sarah, Sarah was better than Sarah] # e.g. text: [Maria was better than Sarah, Sarah was better than Sarah], choice: [so she was envious]
# target: so she was envious # target: 0
assert isinstance(test_text, int), ( assert isinstance(test_text, list), (
f"[{self.config.task}] doc_to_text must return int for multiple inputs" f"[{self.config.task}] doc_to_text must return list for multiple inputs"
) )
assert isinstance(test_target, str), ( assert isinstance(test_target, int), (
f"[{self.config.task}] doc_to_target must return str for multiple inputs" f"[{self.config.task}] doc_to_target must return int label for multiple inputs"
) )
assert self.config.output_type != "generate_until", ( assert self.config.output_type != "generate_until", (
f"[{self.config.task}] Only multiple-choice tasks can be used with multiple inputs" f"[{self.config.task}] Only multiple-choice tasks can be used with multiple inputs"
...@@ -972,10 +973,10 @@ class ConfigurableTask(Task): ...@@ -972,10 +973,10 @@ class ConfigurableTask(Task):
test_target = test_choice[test_target] test_target = test_choice[test_target]
for choice in test_choice: for choice in test_choice:
choice_has_whitespace = choice[0].isspace() choice_has_whitespace, delimiter_has_whitespace = (
delimiter_has_whitespace = ( choice[0].isspace(),
self.config.target_delimiter.rstrip() self.config.target_delimiter.rstrip()
!= self.config.target_delimiter != self.config.target_delimiter,
) )
if delimiter_has_whitespace and choice_has_whitespace: if delimiter_has_whitespace and choice_has_whitespace:
...@@ -1307,14 +1308,6 @@ class ConfigurableTask(Task): ...@@ -1307,14 +1308,6 @@ class ConfigurableTask(Task):
return ast.literal_eval(text_string) return ast.literal_eval(text_string)
elif callable(doc_to_text): elif callable(doc_to_text):
return doc_to_text(doc) return doc_to_text(doc)
# Used when applying a Promptsource template
# elif hasattr(doc_to_text, "apply"):
# applied_prompt = doc_to_text.apply(doc)
# if len(applied_prompt) == 2:
# return applied_prompt[0]
# else:
# eval_logger.warning("Applied prompt returns empty string")
# return self.config.fewshot_delimiter
else: else:
print(type(doc_to_text)) print(type(doc_to_text))
raise TypeError raise TypeError
...@@ -1344,14 +1337,6 @@ class ConfigurableTask(Task): ...@@ -1344,14 +1337,6 @@ class ConfigurableTask(Task):
elif isinstance(doc_to_target, list): elif isinstance(doc_to_target, list):
# ["{{field}}", "{{field}}"] # ["{{field}}", "{{field}}"]
return utils.apply_template(doc_to_target, doc) return utils.apply_template(doc_to_target, doc)
# Used when applying a Promptsource template
# elif hasattr(doc_to_target, "apply"):
# applied_prompt = doc_to_target.apply(doc)
# if len(applied_prompt) == 2:
# return applied_prompt[1]
# else:
# eval_logger.warning("Applied prompt returns empty string")
# return self.config.fewshot_delimiter
else: else:
raise TypeError raise TypeError
...@@ -1374,8 +1359,6 @@ class ConfigurableTask(Task): ...@@ -1374,8 +1359,6 @@ class ConfigurableTask(Task):
return list(doc_to_choice.values()) return list(doc_to_choice.values())
elif callable(doc_to_choice): elif callable(doc_to_choice):
return doc_to_choice(doc) return doc_to_choice(doc)
# elif hasattr(doc_to_choice, "get_answer_choices_list"):
# return doc_to_choice.get_answer_choices_list(doc)
else: else:
raise TypeError raise TypeError
...@@ -1573,7 +1556,11 @@ class ConfigurableTask(Task): ...@@ -1573,7 +1556,11 @@ class ConfigurableTask(Task):
lls, is_greedy = zip(*results) lls, is_greedy = zip(*results)
# retrieve choices in List[str] form, to compute choice lengths, etc. # retrieve choices in List[str] form, to compute choice lengths, etc.
choices = self.doc_to_choice(doc) choices = (
self.doc_to_choice(doc)
if not self.multiple_inputs
else cast(list[str], self.doc_to_text(doc))
)
completion_len = np.array([float(len(i)) for i in choices]) completion_len = np.array([float(len(i)) for i in choices])
if ( if (
...@@ -1592,11 +1579,7 @@ class ConfigurableTask(Task): ...@@ -1592,11 +1579,7 @@ class ConfigurableTask(Task):
pred = np.argmax(lls) pred = np.argmax(lls)
pred_norm = np.argmax(lls / completion_len) pred_norm = np.argmax(lls / completion_len)
gold = backup = ( gold = backup = self.doc_to_target(doc)
self.doc_to_target(doc)
if not self.multiple_inputs
else self.doc_to_text(doc)
)
if isinstance(gold, list): if isinstance(gold, list):
gold = [validate_index(g, len(choices)) for g in gold] gold = [validate_index(g, len(choices)) for g in gold]
......
def doc_to_text(doc): def doc_to_target(doc) -> int:
answer_to_num = {"1": 0, "2": 1} answer_to_num = {"1": 0, "2": 1}
return answer_to_num[doc["answer"]] return answer_to_num[doc["answer"]]
def doc_to_target(doc): def doc_to_choice(doc) -> list[str]:
idx = doc["sentence"].index("_") + 1 idx = doc["sentence"].index("_") + 1
return doc["sentence"][idx:].strip() return [doc["sentence"][idx:].strip()]
def doc_to_choice(doc): def doc_to_text(doc) -> list[str]:
idx = doc["sentence"].index("_") idx = doc["sentence"].index("_")
options = [doc["option1"], doc["option2"]] options = [doc["option1"], doc["option2"]]
return [doc["sentence"][:idx] + opt for opt in options] return [doc["sentence"][:idx] + opt for opt in options]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment