Commit 6ace93ae authored by Baber's avatar Baber
Browse files

add `acc_bytes`

parent 18b910b5
......@@ -179,6 +179,16 @@ def acc_mutual_info_fn(items): # This is a passthrough function
return items
def acc_bytes_fn(items):
    """Pass per-document results through unchanged.

    ``acc_bytes`` (byte-length-normalized accuracy) is computed at
    scoring time; this callable exists only as the registration hook
    for the metric, mirroring the other passthrough metric functions.
    """
    return items


# Explicit registration (equivalent to using @register_metric as a decorator).
acc_bytes_fn = register_metric(
    metric="acc_bytes",
    higher_is_better=True,
    output_type=["loglikelihood", "multiple_choice"],
    aggregation="mean",
)(acc_bytes_fn)
### the code used in the `exact_match_hf_evaluate` function is ported from
### https://github.com/huggingface/evaluate/blob/main/metrics/exact_match/exact_match.py
### which is under the apache license.
......
......@@ -1582,6 +1582,7 @@ class ConfigurableTask(Task):
# retrieve choices in List[str] form, to compute choice lengths, etc.
choices = self.doc_to_choice(doc)
completion_len = np.array([float(len(i)) for i in choices])
byte_length = np.array([float(len(i.encode("utf-8"))) for i in choices])
if (
2 * len(choices) == len(lls)
......@@ -1598,6 +1599,7 @@ class ConfigurableTask(Task):
pred = np.argmax(lls)
pred_norm = np.argmax(lls / completion_len)
pred_byte = np.argmax(lls / byte_length)
if self.multiple_input:
gold = self.doc_to_text(doc)
......@@ -1627,10 +1629,12 @@ class ConfigurableTask(Task):
if self.multiple_target:
acc = 1.0 if pred in gold else 0.0
acc_norm = 1.0 if pred_norm in gold else 0.0
acc_bytes = 1.0 if pred_byte in gold else 0.0
exact_match = int(any([is_greedy[i] if i != -100 else 0 for i in gold]))
else:
acc = 1.0 if pred == gold else 0.0
acc_norm = 1.0 if pred_norm == gold else 0.0
acc_bytes = 1.0 if pred_byte == gold else 0.0
# TODO: this gets score of 0 on arc_challenge for pythia-70m. need to test that this works properly
exact_match = int(is_greedy[gold]) if gold != -100 else 0
......@@ -1643,6 +1647,7 @@ class ConfigurableTask(Task):
**({"f1": (gold, pred)} if "f1" in use_metric else {}),
**({"mcc": (gold, pred)} if "mcc" in use_metric else {}),
**({"acc_norm": acc_norm} if "acc_norm" in use_metric else {}),
**({"acc_bytes": acc_bytes} if "acc_bytes" in use_metric else {}),
**({"exact_match": exact_match} if "exact_match" in use_metric else {}),
**(
{"brier_score": (gold, prob_norm)}
......
......@@ -101,5 +101,11 @@ aggregate_metric_list:
- metric: acc
aggregation: mean
weight_by_size: true
- metric: acc_norm
aggregation: mean
weight_by_size: true
- metric: acc_bytes
aggregation: mean
weight_by_size: true
metadata:
version: 1.0
tag: mrl
dataset_path: mrlbenchmarks/global-piqa-nonparallel
output_type: multiple_choice
test_split: test
......@@ -12,5 +11,8 @@ metric_list:
- metric: acc_norm
aggregation: mean
higher_is_better: true
- metric: acc_bytes
aggregation: mean
higher_is_better: true
metadata:
version: 1.0
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment