"examples/flash_attention/example_gqa_bwd_tma_reduce.py" did not exist on "557589ffd7af10f2740d4bbf5f4f0ce70305ea3c"
Commit 897fbb37 authored by Baber
Browse files

refactor: improve default behavior for metric aggregation and higher-better checks

parent 5c3badbe
...@@ -167,20 +167,24 @@ def get_aggregation(name: str) -> Callable[..., Any] | None: ...@@ -167,20 +167,24 @@ def get_aggregation(name: str) -> Callable[..., Any] | None:
eval_logger.warning(f"{name} not a registered aggregation metric!") eval_logger.warning(f"{name} not a registered aggregation metric!")
def get_metric_aggregation(name: str) -> Callable[[], dict[str, Callable]] | None: def get_metric_aggregation(name: str) -> Callable[[], dict[str, Callable]]:
try: try:
return METRIC_AGGREGATION_REGISTRY[name] return METRIC_AGGREGATION_REGISTRY[name]
except KeyError: except KeyError:
eval_logger.warning(f"{name} metric is not assigned a default aggregation!") eval_logger.warning(
f"{name} metric is not assigned a default aggregation!. Using default aggregation mean"
)
return AGGREGATION_REGISTRY["mean"]
def is_higher_better(metric_name: str) -> bool | None: def is_higher_better(metric_name: str) -> bool:
try: try:
return HIGHER_IS_BETTER_REGISTRY[metric_name] return HIGHER_IS_BETTER_REGISTRY[metric_name]
except KeyError: except KeyError:
eval_logger.warning( eval_logger.warning(
f"higher_is_better not specified for metric '{metric_name}'!" f"higher_is_better not specified for metric '{metric_name}'!. Will default to True."
) )
return True
def register_filter(name: str): def register_filter(name: str):
......
...@@ -240,7 +240,7 @@ class TaskConfig(dict): ...@@ -240,7 +240,7 @@ class TaskConfig(dict):
name=metric_name, name=metric_name,
fn=get_metric(metric_name), fn=get_metric(metric_name),
aggregation_fn=get_metric_aggregation(metric_name), aggregation_fn=get_metric_aggregation(metric_name),
higher_is_better=is_higher_better(metric_name), higher_is_better=is_higher_better(metric_name) or True,
) )
for metric_name in _metric_list for metric_name in _metric_list
) )
......
...@@ -31,7 +31,7 @@ class TemplateConfig: ...@@ -31,7 +31,7 @@ class TemplateConfig:
@dataclass @dataclass
class MCQTemplateConfig: class MCQTemplateConfig(TemplateConfig):
"""Encapsulates information about a template. """Encapsulates information about a template.
Would return a sample with the following format: Would return a sample with the following format:
Question: <doc_to_text(doc)> Question: <doc_to_text(doc)>
...@@ -58,7 +58,7 @@ class MCQTemplateConfig: ...@@ -58,7 +58,7 @@ class MCQTemplateConfig:
@dataclass @dataclass
class ClozeTemplateConfig: class ClozeTemplateConfig(TemplateConfig):
"""Encapsulates information about a template. """Encapsulates information about a template.
Would return a sample with the following format: Would return a sample with the following format:
Question: <doc_to_text(doc)> Question: <doc_to_text(doc)>
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment