Unverified commit 8d1d003d, authored by Lintang Sutawika and committed by GitHub
Browse files

Update task.py

parent cddbf9f6
......@@ -1017,16 +1017,16 @@ class ConfigurableTask(Task):
gold_index_error = False
if type(gold) is list:
gold = [i if i < len(choices) else -1000 for i in gold]
if -1000 in gold:
gold = [i if i < len(choices) else -100 for i in gold]
if -100 in gold:
gold_index_error = True
else:
if type(gold) is int:
gold = gold if gold < len(choices) else -1000
gold = gold if gold < len(choices) else -100
elif type(gold) is str:
gold = choices.index(gold) if gold in choices else -1000
gold = choices.index(gold) if gold in choices else -100
if gold == -1000:
if gold == -100:
gold_index_error = True
if gold_index_error:
......@@ -1039,13 +1039,13 @@ class ConfigurableTask(Task):
acc = 1.0 if pred in gold else 0.0
acc_norm = 1.0 if pred_norm in gold else 0.0
exact_match = int(
any([is_greedy[i] if i != -1000 else 0 for i in gold])
any([is_greedy[i] if i != -100 else 0 for i in gold])
)
else:
acc = 1.0 if pred == gold else 0.0
acc_norm = 1.0 if pred_norm == gold else 0.0
# TODO: this gets score of 0 on arc_challenge for pythia-70m. need to test that this works properly
exact_match = int(is_greedy[gold]) if gold != -1000 else 0
exact_match = int(is_greedy[gold]) if gold != -100 else 0
result_dict = {
**({"acc": acc} if "acc" in use_metric else {}),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment