Commit 61d11b8a authored by Baber
Browse files

require multiple_inputs and multiple_targets to be explicitly set in taskconfig

parent 57b86c47
......@@ -39,6 +39,7 @@ from lm_eval.api.registry import (
from lm_eval.caching.cache import load_from_cache, save_to_cache
from lm_eval.filters import build_filter_ensemble
from lm_eval.prompts import get_prompt
from lm_eval.utils import validate_index
ALL_OUTPUT_TYPES = [
......@@ -96,6 +97,8 @@ class TaskConfig(dict):
should_decontaminate: bool = False
doc_to_decontamination_query: Optional[str] = None
gen_prefix: Optional[str] = None
multiple_inputs: bool = False
multiple_targets: bool = False
metadata: Optional[dict] = (
None # by default, not used in the code. allows for users to pass arbitrary info to tasks
)
......@@ -767,6 +770,12 @@ class ConfigurableTask(Task):
)
self.OUTPUT_TYPE = self.config.output_type
self.multiple_targets = self.config.multiple_targets
self.multiple_inputs = self.config.multiple_inputs
assert not (self.multiple_targets and self.multiple_inputs), (
"Cannot have both multiple_targets and multiple_inputs"
)
if self.config.doc_to_image is not None:
# mark the task as requiring multimodality.
self.MULTIMODAL = True
......@@ -923,50 +932,54 @@ class ConfigurableTask(Task):
# Test One Doc
self.features = list(self.task_docs.features.keys())
self.multiple_input = 0
self.multiple_target = 0
test_doc = self.task_docs[0]
test_text = self.doc_to_text(test_doc)
test_target = self.doc_to_target(test_doc)
if self.config.doc_to_choice is not None:
test_choice = self.doc_to_choice(test_doc)
if not isinstance(test_choice, list):
eval_logger.error("doc_to_choice must return list")
if self.multiple_inputs:
# we require:
# doc_to_text: int
# doc_to_choice: list
# doc_to_target: str
# e.g. text: 1, choice: [Maria was better than Sarah, Sarah was better than Sarah]
# target: so she was envious
assert isinstance(test_text, int), (
"doc_to_text must return int for multiple inputs"
)
assert isinstance(test_target, str), (
"doc_to_target must return str for multiple inputs"
)
assert self.config.output_type != "generate_until", (
"Only multiple-choice tasks can be used with multiple inputs"
)
test_text = test_choice[0]
elif self.multiple_targets:
# we require:
# doc_to_text: str
# doc_to_choice: list
# doc_to_target: list
assert isinstance(test_target, (list, tuple)), (
"doc_to_target must be an iterable for multiple targets"
)
test_target = test_target[0]
else:
num_choice = len(test_choice)
if isinstance(test_text, int):
eval_logger.debug(
"doc_to_text returned an int. Assuming multiple inputs."
assert isinstance(test_target, int), (
"doc_to_target must return int for multiple choices"
)
self.multiple_input = num_choice
else:
test_choice = None
if isinstance(test_target, list):
eval_logger.debug(
"doc_to_target returned a list. Assuming multiple targets."
)
self.multiple_target = len(test_target)
else:
if (isinstance(test_target, int)) and (test_choice is not None):
test_target = test_choice[test_target]
else:
test_target = str(test_target)
if test_choice is not None:
check_choices = test_choice
else:
check_choices = [test_target]
if self.config.doc_to_choice is not None:
for choice in check_choices:
choice_has_whitespace = True if choice[0].isspace() else False
assert hasattr(test_choice, "__iter__") and not isinstance(
test_choice, (str, bytes)
), "doc_to_choice must be an iterable!"
for choice in test_choice:
choice_has_whitespace = choice[0].isspace()
delimiter_has_whitespace = (
True
if self.config.target_delimiter.rstrip()
self.config.target_delimiter.rstrip()
!= self.config.target_delimiter
else False
)
if delimiter_has_whitespace and choice_has_whitespace:
......@@ -1162,7 +1175,7 @@ class ConfigurableTask(Task):
example = self.doc_to_text(doc)
if apply_chat_template:
if self.multiple_input:
if self.multiple_inputs:
# TODO: append prefill?
if not labeled_examples:
return ""
......@@ -1222,7 +1235,7 @@ class ConfigurableTask(Task):
if gen_prefix is not None
else ""
)
if self.multiple_input:
if self.multiple_inputs:
return labeled_examples
if isinstance(example, str):
return labeled_examples + example + prefix
......@@ -1371,7 +1384,7 @@ class ConfigurableTask(Task):
if doc_to_choice in self.features:
return doc[doc_to_choice]
else:
return ast.literal_eval(utils.apply_template(doc_to_choice, doc))
return utils.apply_template(doc_to_choice, doc)
elif isinstance(doc_to_choice, list):
return doc_to_choice
elif isinstance(doc_to_choice, dict):
......@@ -1454,7 +1467,7 @@ class ConfigurableTask(Task):
target_delimiter = self.config.target_delimiter
if apply_chat_template:
target_delimiter = ""
if self.multiple_input:
if self.multiple_inputs:
# If there are multiple inputs, choices are placed in the ctx
# apply chat_template to choices if apply_chat_template
cont = self.doc_to_target(doc)
......@@ -1595,24 +1608,22 @@ class ConfigurableTask(Task):
pred = np.argmax(lls)
pred_norm = np.argmax(lls / completion_len)
if self.multiple_input:
gold = self.doc_to_text(doc)
else:
gold = self.doc_to_target(doc)
# Pick the source of the gold label: with multiple inputs the answer index
# is produced by doc_to_text (the choices live in the context); otherwise
# it is produced by doc_to_target. Both arms previously called doc_to_text,
# so single-input tasks never read their target.
gold = (
    self.doc_to_text(doc)
    if self.multiple_inputs
    else self.doc_to_target(doc)
)
gold_index_error = False
if isinstance(gold, list):
gold = [i if i < len(choices) else -100 for i in gold]
if -100 in gold:
gold_index_error = True
gold = [validate_index(g, len(choices)) for g in gold]
gold_index_error = -100 in gold
else:
if isinstance(gold, int):
gold = gold if gold < len(choices) else -100
gold = validate_index(gold, len(choices))
elif isinstance(gold, str):
gold = choices.index(gold) if gold in choices else -100
if gold == -100:
gold_index_error = True
gold_index_error = gold == -100
if gold_index_error:
eval_logger.warning(
......@@ -1620,7 +1631,7 @@ class ConfigurableTask(Task):
f"Sample:\n\n{doc}\n\n"
)
if self.multiple_target:
if self.multiple_targets:
acc = 1.0 if pred in gold else 0.0
acc_norm = 1.0 if pred_norm in gold else 0.0
exact_match = int(any([is_greedy[i] if i != -100 else 0 for i in gold]))
......
......@@ -9,6 +9,7 @@ doc_to_target: !function preprocess_winogrande.doc_to_target
doc_to_choice: !function preprocess_winogrande.doc_to_choice
should_decontaminate: true
doc_to_decontamination_query: sentence
multiple_inputs: true
metric_list:
- metric: acc
aggregation: mean
......
......@@ -623,3 +623,7 @@ def hash_dict_images(data_dict):
if importlib.util.find_spec("PIL")
else data_dict
)
def validate_index(index: int, length: int) -> int:
    """Return ``index`` if it is below ``length``, else the sentinel ``-100``.

    Only the upper bound is checked: any ``index`` smaller than ``length``
    (negative values included) is returned unchanged.
    """
    if index < length:
        return index
    return -100
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment