Commit c407be5b authored by Baber

move agg metrics to after gather

parent b6ae8c4a
@@ -792,6 +792,8 @@ class ConfigurableTask(Task):
         if self.config.dataset_name is not None:
             self.DATASET_NAME = self.config.dataset_name

+        self.metric_results = []
+
         self._metric_fn_list = {}
         self._metric_fn_kwargs = {}
         self._aggregation_list = {}
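The new metric_results attribute is a per-task cache: the next hunk populates it at the end of calculate_metrics, and compute_agg_metrics gains a matching fallback so it can be called later with no arguments (see the sketch after that hunk).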
@@ -1865,14 +1867,15 @@ class ConfigurableTask(Task):
                 )
                 for metric_name, _score in _sample_metric.items():
                     _all_metrics[(metric_name, filter_key)].append(_score)
+        self.metric_results = _all_metrics
         return _all_metrics, _samples

     def compute_agg_metrics(
         self,
-        metric_results: dict[tuple[str, str], list[list[float]]],
+        metric_results: dict[tuple[str, str], list[list[float]]] = None,
         bootstrap_iters: int = 1000,
     ):
+        metric_results = metric_results if metric_results else self.metric_results
         agg_metrics = defaultdict(list)
         for (metric_name, filter_key), scores in metric_results.items():
             agg_fn = self.aggregation()[metric_name]
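Taken together, the two hunks above decouple aggregation from metric calculation: calculate_metrics caches its per-sample scores on the task, so compute_agg_metrics can be invoked later with no arguments. A minimal, self-contained sketch of that pattern, with a toy TaskSketch class and a plain mean standing in for the harness's real aggregation registry and bootstrap CIs:

from collections import defaultdict
from statistics import mean

class TaskSketch:
    """Toy stand-in for ConfigurableTask, reduced to the caching pattern."""

    def __init__(self):
        # populated by calculate_metrics, read back by compute_agg_metrics
        self.metric_results = {}

    def calculate_metrics(self, scores):
        # scores: {(metric_name, filter_key): [per-sample scores]}
        self.metric_results = scores
        return scores

    def compute_agg_metrics(self, metric_results=None, bootstrap_iters=1000):
        # fall back to the cached per-sample scores when called with no args;
        # bootstrap_iters is kept for signature parity, CIs omitted here
        metric_results = metric_results if metric_results else self.metric_results
        agg_metrics = defaultdict(dict)
        for (metric_name, filter_key), scores in metric_results.items():
            agg_metrics[filter_key][metric_name] = mean(scores)
        return dict(agg_metrics)

task = TaskSketch()
task.calculate_metrics({("acc", "none"): [1.0, 0.0, 1.0]})
print(task.compute_agg_metrics())  # -> {'none': {'acc': 0.666...}}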
@@ -588,21 +588,12 @@ def evaluate(
         ### Collect values of metrics on all datapoints ###
         # # unpack results and sort back in order and return control to Task
         # TODO: make it possible to use a different metric per filter
-        # Pre-process task.instances to group by doc_id
-        # instances_by_doc_id = defaultdict(list)
-        # for instance in task.instances:
-        #     instances_by_doc_id[instance.doc_id].append(instance)
-        # # Sort instances within each group
-        # for instances in instances_by_doc_id.values():
-        #     instances.sort(key=lambda x: x.idx)
-        # iterate over different filters used
         _metrics, samples = task.calculate_metrics(
             indices=samples,
             rank=RANK,
             limit=limit,
             world_size=WORLD_SIZE,
         )
-        task_output.agg_metrics = task.compute_agg_metrics(_metrics)
         task_output.sample_metrics = _metrics
         if log_samples:
             task_output.logged_samples = samples
task_output.agg_metrics = task.compute_agg_metrics(_metrics)
task_output.sample_metrics = _metrics task_output.sample_metrics = _metrics
if log_samples: if log_samples:
task_output.logged_samples = samples task_output.logged_samples = samples
@@ -641,8 +632,10 @@ def evaluate(
     if RANK == 0:
         ### Aggregate results over all datapoints ###
         # aggregate results ; run bootstrap CIs
-        # for task_output in eval_tasks:
-        #     task_output.calculate_aggregate_metric(bootstrap_iters=bootstrap_iters)
+        for task_output in eval_tasks:
+            task_output.agg_metrics = task_output.task.compute_agg_metrics(
+                bootstrap_iters=bootstrap_iters
+            )
         (
             results,
             samples,
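The net effect on evaluate(): every rank still computes (and now caches) its own per-sample metrics, but aggregation runs once on rank 0 after the results are gathered, instead of per rank before the gather. A single-process sketch of the new ordering, with eval_tasks and the task reduced to illustrative stubs (_Task and the SimpleNamespace task output are assumptions, not harness code):

from statistics import mean
from types import SimpleNamespace

RANK, WORLD_SIZE = 0, 1  # single-process stand-in for the distributed run

class _Task:
    def __init__(self, scores):
        self._scores = scores          # {(metric_name, filter_key): [scores]}
        self.metric_results = {}

    def calculate_metrics(self, rank, world_size):
        self.metric_results = self._scores   # cache for the later aggregation
        return self._scores, []              # (per-sample metrics, samples)

    def compute_agg_metrics(self, metric_results=None, bootstrap_iters=1000):
        # bootstrap_iters accepted for signature parity; CIs omitted here
        metric_results = metric_results if metric_results else self.metric_results
        return {k: mean(v) for k, v in metric_results.items()}

eval_tasks = [SimpleNamespace(task=_Task({("acc", "none"): [1.0, 0.0, 1.0]}))]

# 1) every rank computes and caches its own per-sample metrics
for task_output in eval_tasks:
    task_output.sample_metrics, _ = task_output.task.calculate_metrics(
        rank=RANK, world_size=WORLD_SIZE
    )

# 2) ... in the real harness, sample metrics are gathered onto rank 0 here ...

# 3) aggregation is deferred until after the gather and runs once on rank 0
if RANK == 0:
    for task_output in eval_tasks:
        task_output.agg_metrics = task_output.task.compute_agg_metrics(
            bootstrap_iters=1000
        )
    print(eval_tasks[0].agg_metrics)  # -> {('acc', 'none'): 0.666...}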