Commit 61d11b8a authored by Baber
Browse files

Require `multiple_inputs` and `multiple_targets` to be set explicitly in TaskConfig instead of being inferred from the return types of doc_to_text/doc_to_target.

parent 57b86c47
...@@ -39,6 +39,7 @@ from lm_eval.api.registry import ( ...@@ -39,6 +39,7 @@ from lm_eval.api.registry import (
from lm_eval.caching.cache import load_from_cache, save_to_cache from lm_eval.caching.cache import load_from_cache, save_to_cache
from lm_eval.filters import build_filter_ensemble from lm_eval.filters import build_filter_ensemble
from lm_eval.prompts import get_prompt from lm_eval.prompts import get_prompt
from lm_eval.utils import validate_index
ALL_OUTPUT_TYPES = [ ALL_OUTPUT_TYPES = [
...@@ -96,6 +97,8 @@ class TaskConfig(dict): ...@@ -96,6 +97,8 @@ class TaskConfig(dict):
should_decontaminate: bool = False should_decontaminate: bool = False
doc_to_decontamination_query: Optional[str] = None doc_to_decontamination_query: Optional[str] = None
gen_prefix: Optional[str] = None gen_prefix: Optional[str] = None
multiple_inputs: bool = False
multiple_targets: bool = False
metadata: Optional[dict] = ( metadata: Optional[dict] = (
None # by default, not used in the code. allows for users to pass arbitrary info to tasks None # by default, not used in the code. allows for users to pass arbitrary info to tasks
) )
...@@ -767,6 +770,12 @@ class ConfigurableTask(Task): ...@@ -767,6 +770,12 @@ class ConfigurableTask(Task):
) )
self.OUTPUT_TYPE = self.config.output_type self.OUTPUT_TYPE = self.config.output_type
self.multiple_targets = self.config.multiple_targets
self.multiple_inputs = self.config.multiple_inputs
assert not (self.multiple_targets and self.multiple_inputs), (
"Cannot have both multiple_targets and multiple_inputs"
)
if self.config.doc_to_image is not None: if self.config.doc_to_image is not None:
# mark the task as requiring multimodality. # mark the task as requiring multimodality.
self.MULTIMODAL = True self.MULTIMODAL = True
...@@ -923,50 +932,54 @@ class ConfigurableTask(Task): ...@@ -923,50 +932,54 @@ class ConfigurableTask(Task):
# Test One Doc # Test One Doc
self.features = list(self.task_docs.features.keys()) self.features = list(self.task_docs.features.keys())
self.multiple_input = 0
self.multiple_target = 0
test_doc = self.task_docs[0] test_doc = self.task_docs[0]
test_text = self.doc_to_text(test_doc) test_text = self.doc_to_text(test_doc)
test_target = self.doc_to_target(test_doc) test_target = self.doc_to_target(test_doc)
if self.config.doc_to_choice is not None: if self.config.doc_to_choice is not None:
test_choice = self.doc_to_choice(test_doc) test_choice = self.doc_to_choice(test_doc)
if not isinstance(test_choice, list): if self.multiple_inputs:
eval_logger.error("doc_to_choice must return list") # we require:
# doc_to_text: int
# doc_to_choice: list
# doc_to_target: str
# e.g. text: 1, choice: [Maria was better than Sarah, Sarah was better than Maria]
# target: so she was envious
assert isinstance(test_text, int), (
"doc_to_text must return int for multiple inputs"
)
assert isinstance(test_target, str), (
"doc_to_target must return str for multiple inputs"
)
assert self.config.output_type != "generate_until", (
"Only multiple-choice tasks can be used with multiple inputs"
)
test_text = test_choice[0]
elif self.multiple_targets:
# we require:
# doc_to_text: str
# doc_to_choice: list
# doc_to_target: list
assert isinstance(test_target, (list, tuple)), (
"doc_to_target must be an iterable for multiple targets"
)
test_target = test_target[0]
else: else:
num_choice = len(test_choice) assert isinstance(test_target, int), (
"doc_to_target must return int for multiple choices"
if isinstance(test_text, int):
eval_logger.debug(
"doc_to_text returned an int. Assuming multiple inputs."
) )
self.multiple_input = num_choice
else:
test_choice = None
if isinstance(test_target, list):
eval_logger.debug(
"doc_to_target returned a list. Assuming multiple targets."
)
self.multiple_target = len(test_target)
else:
if (isinstance(test_target, int)) and (test_choice is not None):
test_target = test_choice[test_target] test_target = test_choice[test_target]
else:
test_target = str(test_target)
if test_choice is not None: assert hasattr(test_choice, "__iter__") and not isinstance(
check_choices = test_choice test_choice, (str, bytes)
else: ), "doc_to_choice must be an iterable!"
check_choices = [test_target]
if self.config.doc_to_choice is not None: for choice in test_choice:
for choice in check_choices: choice_has_whitespace = choice[0].isspace()
choice_has_whitespace = True if choice[0].isspace() else False
delimiter_has_whitespace = ( delimiter_has_whitespace = (
True self.config.target_delimiter.rstrip()
if self.config.target_delimiter.rstrip()
!= self.config.target_delimiter != self.config.target_delimiter
else False
) )
if delimiter_has_whitespace and choice_has_whitespace: if delimiter_has_whitespace and choice_has_whitespace:
...@@ -1162,7 +1175,7 @@ class ConfigurableTask(Task): ...@@ -1162,7 +1175,7 @@ class ConfigurableTask(Task):
example = self.doc_to_text(doc) example = self.doc_to_text(doc)
if apply_chat_template: if apply_chat_template:
if self.multiple_input: if self.multiple_inputs:
# TODO: append prefill? # TODO: append prefill?
if not labeled_examples: if not labeled_examples:
return "" return ""
...@@ -1222,7 +1235,7 @@ class ConfigurableTask(Task): ...@@ -1222,7 +1235,7 @@ class ConfigurableTask(Task):
if gen_prefix is not None if gen_prefix is not None
else "" else ""
) )
if self.multiple_input: if self.multiple_inputs:
return labeled_examples return labeled_examples
if isinstance(example, str): if isinstance(example, str):
return labeled_examples + example + prefix return labeled_examples + example + prefix
...@@ -1371,7 +1384,7 @@ class ConfigurableTask(Task): ...@@ -1371,7 +1384,7 @@ class ConfigurableTask(Task):
if doc_to_choice in self.features: if doc_to_choice in self.features:
return doc[doc_to_choice] return doc[doc_to_choice]
else: else:
return ast.literal_eval(utils.apply_template(doc_to_choice, doc)) return utils.apply_template(doc_to_choice, doc)
elif isinstance(doc_to_choice, list): elif isinstance(doc_to_choice, list):
return doc_to_choice return doc_to_choice
elif isinstance(doc_to_choice, dict): elif isinstance(doc_to_choice, dict):
...@@ -1454,7 +1467,7 @@ class ConfigurableTask(Task): ...@@ -1454,7 +1467,7 @@ class ConfigurableTask(Task):
target_delimiter = self.config.target_delimiter target_delimiter = self.config.target_delimiter
if apply_chat_template: if apply_chat_template:
target_delimiter = "" target_delimiter = ""
if self.multiple_input: if self.multiple_inputs:
# If there are multiple inputs, choices are placed in the ctx # If there are multiple inputs, choices are placed in the ctx
# apply chat_template to choices if apply_chat_template # apply chat_template to choices if apply_chat_template
cont = self.doc_to_target(doc) cont = self.doc_to_target(doc)
...@@ -1595,24 +1608,22 @@ class ConfigurableTask(Task): ...@@ -1595,24 +1608,22 @@ class ConfigurableTask(Task):
pred = np.argmax(lls) pred = np.argmax(lls)
pred_norm = np.argmax(lls / completion_len) pred_norm = np.argmax(lls / completion_len)
if self.multiple_input: gold = (
gold = self.doc_to_text(doc) self.doc_to_text(doc)
else: if not self.multiple_inputs
gold = self.doc_to_target(doc) else self.doc_to_text(doc)
)
gold_index_error = False
if isinstance(gold, list): if isinstance(gold, list):
gold = [i if i < len(choices) else -100 for i in gold] gold = [validate_index(g, len(choices)) for g in gold]
if -100 in gold: gold_index_error = -100 in gold
gold_index_error = True
else: else:
if isinstance(gold, int): if isinstance(gold, int):
gold = gold if gold < len(choices) else -100 gold = validate_index(gold, len(choices))
elif isinstance(gold, str): elif isinstance(gold, str):
gold = choices.index(gold) if gold in choices else -100 gold = choices.index(gold) if gold in choices else -100
if gold == -100: gold_index_error = gold == -100
gold_index_error = True
if gold_index_error: if gold_index_error:
eval_logger.warning( eval_logger.warning(
...@@ -1620,7 +1631,7 @@ class ConfigurableTask(Task): ...@@ -1620,7 +1631,7 @@ class ConfigurableTask(Task):
f"Sample:\n\n{doc}\n\n" f"Sample:\n\n{doc}\n\n"
) )
if self.multiple_target: if self.multiple_targets:
acc = 1.0 if pred in gold else 0.0 acc = 1.0 if pred in gold else 0.0
acc_norm = 1.0 if pred_norm in gold else 0.0 acc_norm = 1.0 if pred_norm in gold else 0.0
exact_match = int(any([is_greedy[i] if i != -100 else 0 for i in gold])) exact_match = int(any([is_greedy[i] if i != -100 else 0 for i in gold]))
......
...@@ -9,6 +9,7 @@ doc_to_target: !function preprocess_winogrande.doc_to_target ...@@ -9,6 +9,7 @@ doc_to_target: !function preprocess_winogrande.doc_to_target
doc_to_choice: !function preprocess_winogrande.doc_to_choice doc_to_choice: !function preprocess_winogrande.doc_to_choice
should_decontaminate: true should_decontaminate: true
doc_to_decontamination_query: sentence doc_to_decontamination_query: sentence
multiple_inputs: true
metric_list: metric_list:
- metric: acc - metric: acc
aggregation: mean aggregation: mean
......
...@@ -623,3 +623,7 @@ def hash_dict_images(data_dict): ...@@ -623,3 +623,7 @@ def hash_dict_images(data_dict):
if importlib.util.find_spec("PIL") if importlib.util.find_spec("PIL")
else data_dict else data_dict
) )
def validate_index(index: int, length: int) -> int:
    """Validate *index* against an upper bound.

    Returns *index* unchanged when it is a valid position for a sequence of
    *length* items; otherwise returns the sentinel ``-100`` so callers can
    detect and report an out-of-range gold index.  Note that only the upper
    bound is checked — negative indices pass through untouched.
    """
    if index < length:
        return index
    return -100
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment