Commit 51b67365 authored by Baber

Add utility function to log process results paths in ConfigurableTask

parent 8bc4afff
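
For reference, a minimal sketch of how the returned tuple might be consumed — the logging call and the `task` variable below are assumptions for illustration, not part of this commit:

    # Hypothetical usage sketch (not part of this commit): log the anticipated
    # process_results branches for an already-constructed ConfigurableTask `task`.
    task_name, paths, path_summary = log_process_results_path(task)
    eval_logger.debug(f"{task_name} process_results path: {path_summary}")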
...@@ -726,6 +726,130 @@ class Task(abc.ABC):
        return doc_iterator

def log_process_results_path(task_instance) -> Tuple[str, list[str], str]:
    """
    Utility function to determine, for logging, which code paths will be followed
    in the process_results method based on the task's initialization state.

    Args:
        task_instance: The ConfigurableTask instance

    Returns:
        A tuple of (task_name, paths, path_summary), where paths is the list of
        path identifiers that will be taken in process_results and path_summary
        joins them with " -> ".
    """
    paths = []
    task_name = getattr(task_instance.config, "task", "unknown_task")
    output_type = task_instance.OUTPUT_TYPE

    # CRITICAL: Check for custom process_results override first
    if callable(task_instance.config.process_results):
        paths.append("custom_process_results_override")
        path_summary = " -> ".join(paths)
        return task_name, paths, path_summary  # Early return - no other paths matter

    # Base path is always the output type
    base_path = output_type
    paths.append(base_path)

    if "bypass" in task_instance._metric_fn_list.keys():
        paths.append(f"{base_path}_bypass")
        path_summary = " -> ".join(paths)
        return task_name, paths, path_summary

    if output_type == "loglikelihood":
        # Simple path - conditional metric inclusion
        use_metric = list(task_instance._metric_fn_list.keys())
        if "perplexity" in use_metric:
            paths.append(f"{base_path}_perplexity")
        if "acc" in use_metric:
            paths.append(f"{base_path}_acc")

    elif output_type == "loglikelihood_rolling":
        # Conditional metric inclusion
        use_metric = list(task_instance._metric_fn_list.keys())
        if "word_perplexity" in use_metric:
            paths.append(f"{base_path}_word_perplexity")
        if "byte_perplexity" in use_metric:
            paths.append(f"{base_path}_byte_perplexity")
        if "bits_per_byte" in use_metric:
            paths.append(f"{base_path}_bits_per_byte")

    elif output_type == "multiple_choice":
        # Check for mutual info condition (requires both conditions)
        if "acc_mutual_info" in task_instance._metric_fn_list.keys():
            paths.append(f"{base_path}_mutual_info_enabled")

        # Check for multiple_input condition
        if task_instance.multiple_input:
            paths.append(f"{base_path}_multiple_input")

        # Check for multiple_target condition
        if task_instance.multiple_target:
            paths.append(f"{base_path}_multiple_target")
        else:
            paths.append(f"{base_path}_single_target")

        # Track potential gold validation issues
        paths.append(f"{base_path}_gold_validation")

        # Track specific metrics that will be computed
        use_metric = list(task_instance._metric_fn_list.keys())
        metric_paths = []
        for metric in [
            "acc",
            "f1",
            "mcc",
            "acc_norm",
            "exact_match",
            "brier_score",
            "acc_mutual_info",
        ]:
            if metric in use_metric:
                metric_paths.append(metric)
        if metric_paths:
            paths.append(f"{base_path}_metrics_{'_'.join(metric_paths)}")

    elif output_type == "generate_until":
        # Check if doc_to_choice is configured
        # Analyze target structure using test document
        test_doc = task_instance.task_docs[0]
        test_target = task_instance.doc_to_target(test_doc)

        if task_instance.config.doc_to_choice is not None:
            paths.append(f"{base_path}_with_choices")
        # Check if multiple_target (target is a list)
        elif task_instance.multiple_target:
            paths.append(f"{base_path}_multiple_target")
            if not isinstance(test_target, list):
                paths.append(f"{base_path}_multiple_target_type_not_list")
        elif isinstance(test_target, list):
            paths.append(f"{base_path}_target_type_list")
        elif not isinstance(test_target, (str, list)):
            paths.append(f"{base_path}_target_type_{type(test_target).__name__}")

        # Track that we'll loop through all metrics
        paths.append(f"{base_path}_metric_processing_loop")

        # Check for special metric handling
        if task_instance.multiple_target:
            if "exact_match" in task_instance._metric_fn_list.keys():
                paths.append(f"{base_path}_exact_match_multiple_special")
            paths.append(f"{base_path}_multiple_target_metric_aggregation")

        # Check for bypass metrics
        if "bypass" in task_instance._metric_fn_list.keys():
            paths.append(f"{base_path}_bypass_metric")

        # Note: Dict result handling and TypeError handling are runtime-dependent
        paths.append(f"{base_path}_runtime_dependent_paths")

    else:
        paths.append("invalid_output_type_error")

    # Summarize the determined paths
    path_summary = " -> ".join(paths)
    return task_name, paths, path_summary

class ConfigurableTask(Task):
    VERSION = "Yaml"
    OUTPUT_TYPE = None
...@@ -975,6 +1099,7 @@ class ConfigurableTask(Task):
                        eval_logger.debug(
                            f'Both target_delimiter "{self.config.target_delimiter}" and target choice: "{choice}" do not have whitespace, ignore if the language you are evaluating on does not require/use whitespace'
                        )
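        # Record which process_results code paths this task is expected to follow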
        self.paths = log_process_results_path(self)
    def download(
        self, dataset_kwargs: Optional[Dict[str, Any]] = None, **kwargs
...