Commit c407be5b authored by Baber

Move aggregate metric computation to after the gather step
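Per-sample metric scores are now cached on the task when `calculate_metrics` runs on each rank (`self.metric_results`), and `compute_agg_metrics` takes an optional `metric_results` argument that falls back to that cache. The evaluator no longer aggregates inside the per-task loop; aggregation, including bootstrap confidence intervals, runs once on rank 0 after results have been gathered.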

parent b6ae8c4a
@@ -792,6 +792,8 @@ class ConfigurableTask(Task):
        if self.config.dataset_name is not None:
            self.DATASET_NAME = self.config.dataset_name
        self.metric_results = []
        self._metric_fn_list = {}
        self._metric_fn_kwargs = {}
        self._aggregation_list = {}
@@ -1865,14 +1867,15 @@ class ConfigurableTask(Task):
            )
            for metric_name, _score in _sample_metric.items():
                _all_metrics[(metric_name, filter_key)].append(_score)
        self.metric_results = _all_metrics
        return _all_metrics, _samples

    def compute_agg_metrics(
        self,
-        metric_results: dict[tuple[str, str], list[list[float]]],
+        metric_results: dict[tuple[str, str], list[list[float]]] = None,
        bootstrap_iters: int = 1000,
    ):
        metric_results = metric_results if metric_results else self.metric_results
        agg_metrics = defaultdict(list)
        for (metric_name, filter_key), scores in metric_results.items():
            agg_fn = self.aggregation()[metric_name]
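With the `= None` default, callers no longer need to thread the per-sample scores through explicitly. A minimal sketch of both call styles, assuming `task` is a `ConfigurableTask` whose `calculate_metrics` has already run and populated `task.metric_results`:

    # explicit: pass the per-sample scores returned by calculate_metrics
    agg = task.compute_agg_metrics(metric_results, bootstrap_iters=1000)

    # implicit: omit the argument; the scores cached on the task are used
    agg = task.compute_agg_metrics(bootstrap_iters=1000)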
@@ -588,21 +588,12 @@ def evaluate(
        ### Collect values of metrics on all datapoints ###
        # # unpack results and sort back in order and return control to Task
        # TODO: make it possible to use a different metric per filter
        # Pre-process task.instances to group by doc_id
        # instances_by_doc_id = defaultdict(list)
        # for instance in task.instances:
        #     instances_by_doc_id[instance.doc_id].append(instance)
        # # Sort instances within each group
        # for instances in instances_by_doc_id.values():
        #     instances.sort(key=lambda x: x.idx)
        # iterate over different filters used
        _metrics, samples = task.calculate_metrics(
            indices=samples,
            rank=RANK,
            limit=limit,
            world_size=WORLD_SIZE,
        )
        task_output.agg_metrics = task.compute_agg_metrics(_metrics)
        task_output.sample_metrics = _metrics
        if log_samples:
            task_output.logged_samples = samples
@@ -641,8 +632,10 @@ def evaluate(
    if RANK == 0:
        ### Aggregate results over all datapoints ###
        # aggregate results ; run bootstrap CIs
        # for task_output in eval_tasks:
        #     task_output.calculate_aggregate_metric(bootstrap_iters=bootstrap_iters)
        for task_output in eval_tasks:
            task_output.agg_metrics = task_output.task.compute_agg_metrics(
                bootstrap_iters=bootstrap_iters
            )
        (
            results,
            samples,
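Taken together, the new control flow looks roughly like the sketch below. This is an illustration, not the verbatim evaluator; the gather step between the two phases is elided, and the names follow the diff above:

    # every rank: compute per-sample scores only; no aggregation yet
    for task_output in eval_tasks:
        _metrics, samples = task_output.task.calculate_metrics(
            indices=samples, rank=RANK, limit=limit, world_size=WORLD_SIZE
        )
        task_output.sample_metrics = _metrics  # also cached as task.metric_results

    # ... gather sample metrics and logged samples across ranks ...

    # rank 0 only: aggregate the gathered scores and run bootstrap CIs
    if RANK == 0:
        for task_output in eval_tasks:
            task_output.agg_metrics = task_output.task.compute_agg_metrics(
                bootstrap_iters=bootstrap_iters
            )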