Unverified Commit cda25fef authored by Lintang Sutawika's avatar Lintang Sutawika Committed by GitHub
Browse files

Merge branch 'main' into standardize_metrics

parents dfb41835 4d10ad56
...@@ -18,4 +18,4 @@ metric_list: ...@@ -18,4 +18,4 @@ metric_list:
aggregation: mean aggregation: mean
higher_is_better: true higher_is_better: true
metadata: metadata:
- version: 0.0 version: 0.0
...@@ -23,7 +23,6 @@ def parse_args(): ...@@ -23,7 +23,6 @@ def parse_args():
if __name__ == "__main__": if __name__ == "__main__":
args = parse_args() args = parse_args()
# get filename of base_yaml so we can `"include": ` it in our other YAMLs. # get filename of base_yaml so we can `"include": ` it in our other YAMLs.
......
...@@ -173,7 +173,6 @@ all_subtasks = [ ...@@ -173,7 +173,6 @@ all_subtasks = [
def main() -> None: def main() -> None:
for path, task_type in zip( for path, task_type in zip(
["multiple_choice", "generate_until"], ["multiple_choice", "generate_until"],
["multiple_choice_template_yaml", "generate_until_template_yaml"], ["multiple_choice_template_yaml", "generate_until_template_yaml"],
......
...@@ -15,4 +15,4 @@ metric_list: ...@@ -15,4 +15,4 @@ metric_list:
higher_is_better: true higher_is_better: true
ignore_punctuation: true ignore_punctuation: true
metadata: metadata:
- version: 0.0 version: 0.0
# Generated by utils.py
dataset_name: causal_judgment_zero_shot
include: ../multiple_choice_template_yaml
task: bigbench_causal_judgement_multiple_choice
...@@ -12,4 +12,4 @@ metric_list: ...@@ -12,4 +12,4 @@ metric_list:
- metric: acc - metric: acc
# TODO: brier score and other metrics # TODO: brier score and other metrics
metadata: metadata:
- version: 0.0 version: 0.0
...@@ -11,4 +11,4 @@ doc_to_decontamination_query: "{{sentence_good}} {{sentence_bad}}" ...@@ -11,4 +11,4 @@ doc_to_decontamination_query: "{{sentence_good}} {{sentence_bad}}"
metric_list: metric_list:
- metric: acc - metric: acc
metadata: metadata:
- version: 1.0 version: 1.0
...@@ -73,7 +73,6 @@ all_subtasks = [ ...@@ -73,7 +73,6 @@ all_subtasks = [
def main() -> None: def main() -> None:
for task in all_subtasks: for task in all_subtasks:
file_name = f"{task}.yaml" file_name = f"{task}.yaml"
try: try:
with open(f"{file_name}", "w") as f: with open(f"{file_name}", "w") as f:
......
...@@ -16,4 +16,4 @@ metric_list: ...@@ -16,4 +16,4 @@ metric_list:
aggregation: mean aggregation: mean
higher_is_better: true higher_is_better: true
metadata: metadata:
- version: 1.0 version: 1.0
...@@ -75,7 +75,6 @@ def parse_args(): ...@@ -75,7 +75,6 @@ def parse_args():
if __name__ == "__main__": if __name__ == "__main__":
args = parse_args() args = parse_args()
# get filename of base_yaml so we can `"include": ` it in our other YAMLs. # get filename of base_yaml so we can `"include": ` it in our other YAMLs.
...@@ -93,7 +92,9 @@ if __name__ == "__main__": ...@@ -93,7 +92,9 @@ if __name__ == "__main__":
if args.cot_prompt_path is not None: if args.cot_prompt_path is not None:
description = cot_file[subject_eng] description = cot_file[subject_eng]
else: else:
description = f"以下是中国关于{subject_zh}的单项选择题,请选出其中的正确答案。\n\n" description = (
f"以下是中国关于{subject_zh}的单项选择题,请选出其中的正确答案。\n\n"
)
yaml_dict = { yaml_dict = {
"include": base_yaml_name, "include": base_yaml_name,
......
...@@ -16,4 +16,4 @@ metric_list: ...@@ -16,4 +16,4 @@ metric_list:
aggregation: mean aggregation: mean
higher_is_better: true higher_is_better: true
metadata: metadata:
- version: 0.0 version: 0.0
...@@ -90,7 +90,6 @@ def parse_args(): ...@@ -90,7 +90,6 @@ def parse_args():
if __name__ == "__main__": if __name__ == "__main__":
args = parse_args() args = parse_args()
# get filename of base_yaml so we can `"include": ` it in our other YAMLs. # get filename of base_yaml so we can `"include": ` it in our other YAMLs.
...@@ -108,7 +107,9 @@ if __name__ == "__main__": ...@@ -108,7 +107,9 @@ if __name__ == "__main__":
if args.cot_prompt_path is not None: if args.cot_prompt_path is not None:
description = cot_file[subject_eng] description = cot_file[subject_eng]
else: else:
description = f"以下是关于{subject_zh}的单项选择题,请直接给出正确答案的选项。\n\n" description = (
f"以下是关于{subject_zh}的单项选择题,请直接给出正确答案的选项。\n\n"
)
yaml_dict = { yaml_dict = {
"include": base_yaml_name, "include": base_yaml_name,
......
#!/usr/bin/python #!/usr/bin/python
import os
import re import re
import sys import sys
import math import math
import subprocess
import xml.sax.saxutils import xml.sax.saxutils
from typing import List, Pattern, Tuple, Union, Dict, Any, Optional from typing import List, Pattern, Tuple, Union, Dict, Any, Optional
...@@ -65,14 +63,14 @@ def normalize(s): ...@@ -65,14 +63,14 @@ def normalize(s):
if type(s) is not str: if type(s) is not str:
s = " ".join(s) s = " ".join(s)
# language-independent part: # language-independent part:
for (pattern, replace) in normalize1: for pattern, replace in normalize1:
s = re.sub(pattern, replace, s) s = re.sub(pattern, replace, s)
s = xml.sax.saxutils.unescape(s, {""": '"'}) s = xml.sax.saxutils.unescape(s, {""": '"'})
# language-dependent part (assuming Western languages): # language-dependent part (assuming Western languages):
s = " %s " % s s = " %s " % s
if not preserve_case: if not preserve_case:
s = s.lower() # this might not be identical to the original s = s.lower() # this might not be identical to the original
for (pattern, replace) in normalize2: for pattern, replace in normalize2:
s = re.sub(pattern, replace, s) s = re.sub(pattern, replace, s)
return s.split() return s.split()
...@@ -95,7 +93,7 @@ def cook_refs(refs, n=4): ...@@ -95,7 +93,7 @@ def cook_refs(refs, n=4):
maxcounts: Dict[Tuple[str], int] = {} maxcounts: Dict[Tuple[str], int] = {}
for ref in refs: for ref in refs:
counts = count_ngrams(ref, n) counts = count_ngrams(ref, n)
for (ngram, count) in counts.items(): for ngram, count in counts.items():
maxcounts[ngram] = max(maxcounts.get(ngram, 0), count) maxcounts[ngram] = max(maxcounts.get(ngram, 0), count)
return ([len(ref) for ref in refs], maxcounts) return ([len(ref) for ref in refs], maxcounts)
...@@ -125,7 +123,7 @@ def cook_test(test, item, n=4): ...@@ -125,7 +123,7 @@ def cook_test(test, item, n=4):
result["correct"] = [0] * n result["correct"] = [0] * n
counts = count_ngrams(test, n) counts = count_ngrams(test, n)
for (ngram, count) in counts.items(): for ngram, count in counts.items():
result["correct"][len(ngram) - 1] += min(refmaxcounts.get(ngram, 0), count) result["correct"][len(ngram) - 1] += min(refmaxcounts.get(ngram, 0), count)
return result return result
...@@ -222,7 +220,6 @@ def bleuFromMaps(m1, m2): ...@@ -222,7 +220,6 @@ def bleuFromMaps(m1, m2):
def smoothed_bleu_4(references, predictions, **kwargs): def smoothed_bleu_4(references, predictions, **kwargs):
predictionMap = {} predictionMap = {}
goldMap = {} goldMap = {}
......
...@@ -18,4 +18,4 @@ metric_list: ...@@ -18,4 +18,4 @@ metric_list:
aggregation: mean aggregation: mean
higher_is_better: True higher_is_better: True
metadata: metadata:
- version: 0.0 version: 0.0
...@@ -18,4 +18,4 @@ metric_list: ...@@ -18,4 +18,4 @@ metric_list:
aggregation: mean aggregation: mean
higher_is_better: True higher_is_better: True
metadata: metadata:
- version: 0.0 version: 0.0
...@@ -18,4 +18,4 @@ metric_list: ...@@ -18,4 +18,4 @@ metric_list:
aggregation: mean aggregation: mean
higher_is_better: True higher_is_better: True
metadata: metadata:
- version: 0.0 version: 0.0
...@@ -18,4 +18,4 @@ metric_list: ...@@ -18,4 +18,4 @@ metric_list:
aggregation: mean aggregation: mean
higher_is_better: True higher_is_better: True
metadata: metadata:
- version: 0.0 version: 0.0
...@@ -18,4 +18,4 @@ metric_list: ...@@ -18,4 +18,4 @@ metric_list:
aggregation: mean aggregation: mean
higher_is_better: True higher_is_better: True
metadata: metadata:
- version: 0.0 version: 0.0
...@@ -18,4 +18,4 @@ metric_list: ...@@ -18,4 +18,4 @@ metric_list:
aggregation: mean aggregation: mean
higher_is_better: True higher_is_better: True
metadata: metadata:
- version: 2.0 version: 2.0
def doc_to_text(doc): def doc_to_text(doc):
inputs = " ".join(doc["code_tokens"]).replace("\n", " ") inputs = " ".join(doc["code_tokens"]).replace("\n", " ")
inputs = " ".join(inputs.strip().split()) inputs = " ".join(inputs.strip().split())
...@@ -7,7 +6,6 @@ def doc_to_text(doc): ...@@ -7,7 +6,6 @@ def doc_to_text(doc):
def doc_to_target(doc): def doc_to_target(doc):
targets = " ".join(doc["docstring_tokens"]).replace("\n", "") targets = " ".join(doc["docstring_tokens"]).replace("\n", "")
targets = " ".join(targets.strip().split()) targets = " ".join(targets.strip().split())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment