Commit e37698df authored by lintangsutawika's avatar lintangsutawika
Browse files

update on metrics and delete files

parent 1bc408ff
......@@ -555,12 +555,17 @@ class ConfigurableTask(Task):
if key not in ["metric", "aggregation", "higher_is_better"]
}
if self._config.process_results is None:
self._metric_fn_list[metric_name] = get_metric(metric_name)
self._metric_fn_kwargs[metric_name] = kwargs
else:
if self._config.process_results is not None:
self._metric_fn_list[metric_name] = None
self._metric_fn_kwargs[metric_name] = {}
elif callable(metric_name):
metric_fn = metric_name.__call__
metric_name = metric_name.__name__
self._metric_fn_list[metric_name] = metric_fn
self._metric_fn_kwargs[metric_name] = kwargs
else:
self._metric_fn_list[metric_name] = get_metric(metric_name)
self._metric_fn_kwargs[metric_name] = kwargs
if "aggregation" in metric_config:
agg_name = metric_config["aggregation"]
......@@ -987,6 +992,7 @@ class ConfigurableTask(Task):
choices = self.doc_to_choice(doc)
gold = choices[gold]
print(self._metric_fn_list)
for key, result in zip(self._metric_fn_list.keys(), results):
if self.multiple_target:
# in the case where we have multiple targets,
......
......@@ -419,10 +419,14 @@ def evaluate(
versions[group] = "N/A"
results_dict = {
"results": dict(results),
**({"aggregate": dict(aggregate)} if bool(aggregate) else {}),
"configs": dict(configs),
"versions": dict(versions),
"results": dict(sorted(results.items())),
**(
{"aggregate": dict(sorted(aggregate.items()))}
if bool(aggregate)
else {}
),
"configs": dict(sorted(configs.items())),
"versions": dict(sorted(versions.items())),
}
if log_samples:
results_dict["samples"] = dict(samples)
......
......@@ -30,6 +30,7 @@ task:
use_prompt: promptsource:*
training_split: train
validation_split: validation
output_type: greedy_until
metric_list:
- metric: exact_match
aggregation: mean
......@@ -37,17 +38,17 @@ task:
ignore_case: true
ignore_punctuation: true
# Natural Language Inference
- dataset_path: super_glue
dataset_name: rte
use_prompt: promptsource:*
training_split: train
validation_split: validation
metric_list:
- metric: exact_match
aggregation: mean
higher_is_better: true
ignore_case: true
ignore_punctuation: true
# - dataset_path: super_glue
# dataset_name: rte
# use_prompt: promptsource:*
# training_split: train
# validation_split: validation
# metric_list:
# - metric: exact_match
# aggregation: mean
# higher_is_better: true
# ignore_case: true
# ignore_punctuation: true
# # Natural Language Inference
# # - dataset_path: anli
# # use_prompt: promptsource:*
......
......@@ -15,5 +15,5 @@ metric_list:
higher_is_better: true
ignore_case: true
ignore_punctuation: true
# - metric: f1
# aggregation: !function "aggregate.cb_multi_fi"
- metric: f1
aggregation: !function "aggregate.cb_multi_fi"
......@@ -6,7 +6,7 @@ dataset_name: rte
output_type: multiple_choice
training_split: train
validation_split: validation
doc_to_text: "{{sentence1}}\nQuestion: {{sentence2}} True or False?\nAnswer:"
doc_to_text: "{{premise}}\nQuestion: {{hypothesis}} True or False?\nAnswer:"
doc_to_target: label
doc_to_choice: ['True', 'False']
metric_list:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment