Unverified commit cda25fef, authored by Lintang Sutawika, committed by GitHub

Merge branch 'main' into standardize_metrics

parents dfb41835 4d10ad56
@@ -18,4 +18,4 @@ metric_list:
    aggregation: mean
    higher_is_better: true
metadata:
-  - version: 0.0
+  version: 0.0
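Every task-config hunk in this merge makes the same change: the dash under `metadata:` is dropped, so `version` stops being a one-element list and becomes a plain mapping key. A minimal sketch of the difference as seen by a YAML consumer (the `yaml.safe_load` call is the standard PyYAML API; the snippet itself is illustrative, not part of this diff):

import yaml

old = yaml.safe_load("metadata:\n  - version: 0.0")
new = yaml.safe_load("metadata:\n  version: 0.0")

print(old["metadata"])             # [{'version': 0.0}] -- a list holding one dict
print(new["metadata"])             # {'version': 0.0}  -- a flat mapping
print(new["metadata"]["version"])  # 0.0, no list indexing needed

The mapping form lets consuming code read `config["metadata"]["version"]` directly instead of special-casing a single-element list.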
@@ -23,7 +23,6 @@ def parse_args():
if __name__ == "__main__":
    args = parse_args()
    # get filename of base_yaml so we can `"include": ` it in our other YAMLs.
......
@@ -173,7 +173,6 @@ all_subtasks = [
def main() -> None:
    for path, task_type in zip(
        ["multiple_choice", "generate_until"],
        ["multiple_choice_template_yaml", "generate_until_template_yaml"],
......
@@ -15,4 +15,4 @@ metric_list:
    higher_is_better: true
    ignore_punctuation: true
metadata:
-  - version: 0.0
+  version: 0.0
# Generated by utils.py
dataset_name: causal_judgment_zero_shot
include: ../multiple_choice_template_yaml
task: bigbench_causal_judgement_multiple_choice
@@ -12,4 +12,4 @@ metric_list:
  - metric: acc
# TODO: brier score and other metrics
metadata:
-  - version: 0.0
+  version: 0.0
@@ -11,4 +11,4 @@ doc_to_decontamination_query: "{{sentence_good}} {{sentence_bad}}"
metric_list:
  - metric: acc
metadata:
-  - version: 1.0
+  version: 1.0
@@ -73,7 +73,6 @@ all_subtasks = [
def main() -> None:
    for task in all_subtasks:
        file_name = f"{task}.yaml"
        try:
            with open(f"{file_name}", "w") as f:
......
@@ -16,4 +16,4 @@ metric_list:
    aggregation: mean
    higher_is_better: true
metadata:
-  - version: 1.0
+  version: 1.0
@@ -75,7 +75,6 @@ def parse_args():
if __name__ == "__main__":
    args = parse_args()
    # get filename of base_yaml so we can `"include": ` it in our other YAMLs.
@@ -93,7 +92,9 @@ if __name__ == "__main__":
        if args.cot_prompt_path is not None:
            description = cot_file[subject_eng]
        else:
-            description = f"以下是中国关于{subject_zh}的单项选择题,请选出其中的正确答案。\n\n"
+            description = (
+                f"以下是中国关于{subject_zh}的单项选择题,请选出其中的正确答案。\n\n"
+            )

        yaml_dict = {
            "include": base_yaml_name,
......
@@ -16,4 +16,4 @@ metric_list:
    aggregation: mean
    higher_is_better: true
metadata:
-  - version: 0.0
+  version: 0.0
@@ -90,7 +90,6 @@ def parse_args():
if __name__ == "__main__":
    args = parse_args()
    # get filename of base_yaml so we can `"include": ` it in our other YAMLs.
@@ -108,7 +107,9 @@ if __name__ == "__main__":
        if args.cot_prompt_path is not None:
            description = cot_file[subject_eng]
        else:
-            description = f"以下是关于{subject_zh}的单项选择题,请直接给出正确答案的选项。\n\n"
+            description = (
+                f"以下是关于{subject_zh}的单项选择题,请直接给出正确答案的选项。\n\n"
+            )

        yaml_dict = {
            "include": base_yaml_name,
......
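The two description hunks above only re-wrap a long prompt line to satisfy the formatter; the f-strings read, roughly, "The following are single-choice questions about {subject} (in China); please select / directly give the correct answer." For orientation, a hypothetical condensation of the per-subject generation pattern these utils.py scripts follow; the template name, subject pair, and output path below are invented for illustration, not taken from the diff:

import yaml

base_yaml_name = "_default_template_yaml"  # assumed template file name
subject_eng, subject_zh = "computer_science", "计算机科学"  # hypothetical subject

# "The following are single-choice questions about {subject_zh};
#  please directly give the option of the correct answer."
description = (
    f"以下是关于{subject_zh}的单项选择题,请直接给出正确答案的选项。\n\n"
)

yaml_dict = {
    "include": base_yaml_name,
    "task": f"cmmlu_{subject_eng}",
    "dataset_name": subject_eng,
    "description": description,
}

with open(f"cmmlu_{subject_eng}.yaml", "w", encoding="utf-8") as f:
    yaml.dump(yaml_dict, f, allow_unicode=True, default_flow_style=False)

Each generated file then only `include`s the shared template and overrides its subject-specific fields.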
#!/usr/bin/python
import os
import re
import sys
import math
import subprocess
import xml.sax.saxutils
from typing import List, Pattern, Tuple, Union, Dict, Any, Optional
@@ -65,14 +63,14 @@ def normalize(s):
    if type(s) is not str:
        s = " ".join(s)
    # language-independent part:
-    for (pattern, replace) in normalize1:
+    for pattern, replace in normalize1:
        s = re.sub(pattern, replace, s)
    s = xml.sax.saxutils.unescape(s, {"&quot;": '"'})
    # language-dependent part (assuming Western languages):
    s = " %s " % s
    if not preserve_case:
        s = s.lower()  # this might not be identical to the original
-    for (pattern, replace) in normalize2:
+    for pattern, replace in normalize2:
        s = re.sub(pattern, replace, s)
    return s.split()
@@ -95,7 +93,7 @@ def cook_refs(refs, n=4):
    maxcounts: Dict[Tuple[str], int] = {}
    for ref in refs:
        counts = count_ngrams(ref, n)
-        for (ngram, count) in counts.items():
+        for ngram, count in counts.items():
            maxcounts[ngram] = max(maxcounts.get(ngram, 0), count)
    return ([len(ref) for ref in refs], maxcounts)
@@ -125,7 +123,7 @@ def cook_test(test, item, n=4):
    result["correct"] = [0] * n
    counts = count_ngrams(test, n)
-    for (ngram, count) in counts.items():
+    for ngram, count in counts.items():
        result["correct"][len(ngram) - 1] += min(refmaxcounts.get(ngram, 0), count)
    return result
@@ -222,7 +220,6 @@ def bleuFromMaps(m1, m2):
def smoothed_bleu_4(references, predictions, **kwargs):
    predictionMap = {}
    goldMap = {}
......
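The hunks in this BLEU script are pure style changes (dropping parentheses around tuple unpacking in for loops). For readers unfamiliar with the metric the script implements, here is a self-contained, add-one-smoothed sentence BLEU-4 in the same spirit as the clipped n-gram counting of `cook_refs`/`cook_test`; this is a sketch of the general technique, not the script's exact algorithm, and all names below are mine:

import math
from collections import Counter

def ngrams(tokens, n):
    # all contiguous n-grams of a token list
    return [tuple(tokens[i : i + n]) for i in range(len(tokens) - n + 1)]

def smoothed_sentence_bleu4(reference, candidate):
    log_precision_sum = 0.0
    for n in range(1, 5):
        cand_counts = Counter(ngrams(candidate, n))
        ref_counts = Counter(ngrams(reference, n))
        # clipped n-gram matches, as in cook_test's refmaxcounts lookup
        overlap = sum(min(c, ref_counts[g]) for g, c in cand_counts.items())
        total = sum(cand_counts.values())
        # add-one smoothing keeps a zero-match order from zeroing the score
        log_precision_sum += math.log((overlap + 1) / (total + 1))
    # brevity penalty for candidates shorter than the reference
    if len(candidate) >= len(reference):
        bp = 1.0
    else:
        bp = math.exp(1 - len(reference) / max(len(candidate), 1))
    return bp * math.exp(log_precision_sum / 4)

print(smoothed_sentence_bleu4("the cat sat".split(), "the cat sat".split()))  # 1.0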
@@ -18,4 +18,4 @@ metric_list:
    aggregation: mean
    higher_is_better: True
metadata:
-  - version: 0.0
+  version: 0.0
@@ -18,4 +18,4 @@ metric_list:
    aggregation: mean
    higher_is_better: True
metadata:
-  - version: 0.0
+  version: 0.0
@@ -18,4 +18,4 @@ metric_list:
    aggregation: mean
    higher_is_better: True
metadata:
-  - version: 0.0
+  version: 0.0
@@ -18,4 +18,4 @@ metric_list:
    aggregation: mean
    higher_is_better: True
metadata:
-  - version: 0.0
+  version: 0.0
@@ -18,4 +18,4 @@ metric_list:
    aggregation: mean
    higher_is_better: True
metadata:
-  - version: 0.0
+  version: 0.0
@@ -18,4 +18,4 @@ metric_list:
    aggregation: mean
    higher_is_better: True
metadata:
-  - version: 2.0
+  version: 2.0
def doc_to_text(doc):
    inputs = " ".join(doc["code_tokens"]).replace("\n", " ")
    inputs = " ".join(inputs.strip().split())
@@ -7,7 +6,6 @@ def doc_to_text(doc):
def doc_to_target(doc):
    targets = " ".join(doc["docstring_tokens"]).replace("\n", "")
    targets = " ".join(targets.strip().split())
......
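These two helpers just detokenize and whitespace-normalize the pre-tokenized code and docstring fields. A quick illustration with a made-up doc (field values invented; the structure is assumed from the snippet above):

doc = {
    "code_tokens": ["def", "add", "(", "a", ",", "b", ")", ":", "\n", "return", "a", "+", "b"],
    "docstring_tokens": ["Add", "two", "numbers", "."],
}

inputs = " ".join(doc["code_tokens"]).replace("\n", " ")
print(" ".join(inputs.strip().split()))   # def add ( a , b ) : return a + b

targets = " ".join(doc["docstring_tokens"]).replace("\n", "")
print(" ".join(targets.strip().split()))  # Add two numbers .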