Commit 6ace93ae authored by Baber

add `acc_bytes`

parent 18b910b5
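As context for the diff below: `acc_norm` normalizes each choice's loglikelihood by its character count, while the new `acc_bytes` normalizes by its UTF-8 byte count, which can change the prediction for scripts where one character spans several bytes. A minimal standalone sketch of the difference (the loglikelihood values and choice strings are made up for illustration, not taken from the harness or this commit's data):

```python
import numpy as np

# Made-up loglikelihoods and choices; "да" is 2 characters but 4 UTF-8 bytes.
choice_lls = np.array([-50.0, -48.0])
choices = ["да", "no"]

# acc_norm-style prediction: normalize by character length.
completion_len = np.array([float(len(c)) for c in choices])
pred_norm = int(np.argmax(choice_lls / completion_len))  # -> 1 ("no")

# acc_bytes-style prediction: normalize by UTF-8 byte length, mirroring the
# `byte_length` / `pred_byte` lines added in this commit.
byte_length = np.array([float(len(c.encode("utf-8"))) for c in choices])
pred_byte = int(np.argmax(choice_lls / byte_length))  # -> 0 ("да")

print(pred_norm, pred_byte)  # the two normalizations can pick different answers
```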
@@ -179,6 +179,16 @@ def acc_mutual_info_fn(items):  # This is a passthrough function
     return items


+@register_metric(
+    metric="acc_bytes",
+    higher_is_better=True,
+    output_type=["loglikelihood", "multiple_choice"],
+    aggregation="mean",
+)
+def acc_bytes_fn(items):  # This is a passthrough function
+    return items
+
+
 ### the code used in the `exact_match_hf_evaluate` function is ported from
 ### https://github.com/huggingface/evaluate/blob/main/metrics/exact_match/exact_match.py
 ### which is under the apache license.
...
@@ -1582,6 +1582,7 @@ class ConfigurableTask(Task):
             # retrieve choices in List[str] form, to compute choice lengths, etc.
             choices = self.doc_to_choice(doc)
             completion_len = np.array([float(len(i)) for i in choices])
+            byte_length = np.array([float(len(i.encode("utf-8"))) for i in choices])

             if (
                 2 * len(choices) == len(lls)
@@ -1598,6 +1599,7 @@ class ConfigurableTask(Task):

             pred = np.argmax(lls)
             pred_norm = np.argmax(lls / completion_len)
+            pred_byte = np.argmax(lls / byte_length)

             if self.multiple_input:
                 gold = self.doc_to_text(doc)
@@ -1627,10 +1629,12 @@ class ConfigurableTask(Task):
             if self.multiple_target:
                 acc = 1.0 if pred in gold else 0.0
                 acc_norm = 1.0 if pred_norm in gold else 0.0
+                acc_bytes = 1.0 if pred_byte in gold else 0.0
                 exact_match = int(any([is_greedy[i] if i != -100 else 0 for i in gold]))
             else:
                 acc = 1.0 if pred == gold else 0.0
                 acc_norm = 1.0 if pred_norm == gold else 0.0
+                acc_bytes = 1.0 if pred_byte == gold else 0.0
                 # TODO: this gets score of 0 on arc_challenge for pythia-70m. need to test that this works properly
                 exact_match = int(is_greedy[gold]) if gold != -100 else 0

@@ -1643,6 +1647,7 @@ class ConfigurableTask(Task):
                 **({"f1": (gold, pred)} if "f1" in use_metric else {}),
                 **({"mcc": (gold, pred)} if "mcc" in use_metric else {}),
                 **({"acc_norm": acc_norm} if "acc_norm" in use_metric else {}),
+                **({"acc_bytes": acc_bytes} if "acc_bytes" in use_metric else {}),
                 **({"exact_match": exact_match} if "exact_match" in use_metric else {}),
                 **(
                     {"brier_score": (gold, prob_norm)}
...
@@ -101,5 +101,11 @@ aggregate_metric_list:
   - metric: acc
     aggregation: mean
     weight_by_size: true
+  - metric: acc_norm
+    aggregation: mean
+    weight_by_size: true
+  - metric: acc_bytes
+    aggregation: mean
+    weight_by_size: true
 metadata:
   version: 1.0
...
-tag: mrl
 dataset_path: mrlbenchmarks/global-piqa-nonparallel
 output_type: multiple_choice
 test_split: test
@@ -12,5 +11,8 @@ metric_list:
   - metric: acc_norm
     aggregation: mean
     higher_is_better: true
+  - metric: acc_bytes
+    aggregation: mean
+    higher_is_better: true
 metadata:
   version: 1.0
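For what the YAML options above amount to, a rough sketch of the implied arithmetic (not harness internals; subtask names and per-document values below are invented): `aggregation: mean` averages the per-document 0/1 `acc_bytes` values within a task, and `weight_by_size: true` weights each subtask's mean by its document count when computing the group score.

```python
import numpy as np

# Hypothetical per-document acc_bytes values for two subtasks of a group
# (0.0 / 1.0 per doc, as emitted by process_results in the diff above).
subtask_scores = {
    "subtask_a": np.array([1.0, 0.0, 1.0, 1.0]),
    "subtask_b": np.array([0.0, 1.0]),
}

# `aggregation: mean` -> per-subtask score is the mean over documents.
per_task = {name: vals.mean() for name, vals in subtask_scores.items()}

# `weight_by_size: true` -> group score is a size-weighted mean of subtask
# scores, equivalent to pooling all documents and taking one overall mean.
sizes = {name: len(vals) for name, vals in subtask_scores.items()}
group_score = sum(per_task[n] * sizes[n] for n in per_task) / sum(sizes.values())

print(per_task, group_score)  # {'subtask_a': 0.75, 'subtask_b': 0.5}, 0.666...
```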