Commit 0d1ef037 authored by lintangsutawika

solved merge conflict

parents aa44be3f ada4a31d
@@ -16,4 +16,4 @@ metric_list:
     aggregation: mean
     higher_is_better: true
 metadata:
-  - version: 0.0
+  version: 0.0
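
This `metadata` edit recurs across many task YAMLs in this commit: `version` moves from a one-element list under `metadata` to a plain mapping key. A minimal sketch (assuming PyYAML, which the generator scripts in this diff already use via `yaml.full_load`) of what each form parses to:

import yaml

# Old form: metadata parses to a *list* containing one dict.
old = yaml.safe_load("metadata:\n  - version: 0.0")
# New form: metadata parses to a plain mapping.
new = yaml.safe_load("metadata:\n  version: 0.0")

print(old)  # {'metadata': [{'version': 0.0}]}
print(new)  # {'metadata': {'version': 0.0}}

# The mapping form allows direct access,
assert new["metadata"]["version"] == 0.0
# whereas the list form needs an extra index: old["metadata"][0]["version"]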
@@ -90,7 +90,6 @@ def parse_args():
 if __name__ == "__main__":
     args = parse_args()
     # get filename of base_yaml so we can `"include": ` it in our other YAMLs.
@@ -108,7 +107,9 @@ if __name__ == "__main__":
         if args.cot_prompt_path is not None:
             description = cot_file[subject_eng]
         else:
-            description = f"以下是关于{subject_zh}的单项选择题,请直接给出正确答案的选项。\n\n"
+            description = (
+                f"以下是关于{subject_zh}的单项选择题,请直接给出正确答案的选项。\n\n"
+            )

         yaml_dict = {
             "include": base_yaml_name,
......
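
For context, the hunk above sits in a config-generator script: each per-subject YAML records only what differs and `include`s a shared base file (the Chinese prompt string reads: "The following are multiple-choice questions about {subject}; please directly give the option of the correct answer."). A rough sketch of the pattern; the subject data and any field names beyond those visible in the hunk are assumptions:

import os
import yaml

base_yaml_path = "_default_template_yaml"  # assumed base-config filename
base_yaml_name = os.path.split(base_yaml_path)[-1]  # what gets `include`d

SUBJECTS = {"agronomy": "农学"}  # hypothetical (english, chinese) subject pair

for subject_eng, subject_zh in SUBJECTS.items():
    yaml_dict = {
        "include": base_yaml_name,
        "task": f"cmmlu_{subject_eng}",  # task naming is an assumption
        "description": f"以下是关于{subject_zh}的单项选择题,请直接给出正确答案的选项。\n\n",
    }
    # write one small YAML per subject, pointing back at the base config
    with open(f"cmmlu_{subject_eng}.yaml", "w", encoding="utf-8") as f:
        yaml.dump(yaml_dict, f, allow_unicode=True)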
 #!/usr/bin/python
-import os
 import re
 import sys
 import math
-import subprocess
 import xml.sax.saxutils
 from typing import List, Pattern, Tuple, Union, Dict, Any, Optional
@@ -65,14 +63,14 @@ def normalize(s):
     if type(s) is not str:
         s = " ".join(s)
     # language-independent part:
-    for (pattern, replace) in normalize1:
+    for pattern, replace in normalize1:
         s = re.sub(pattern, replace, s)
     s = xml.sax.saxutils.unescape(s, {"&quot;": '"'})
     # language-dependent part (assuming Western languages):
     s = " %s " % s
     if not preserve_case:
         s = s.lower()  # this might not be identical to the original
-    for (pattern, replace) in normalize2:
+    for pattern, replace in normalize2:
         s = re.sub(pattern, replace, s)
     return s.split()
@@ -95,7 +93,7 @@ def cook_refs(refs, n=4):
     maxcounts: Dict[Tuple[str], int] = {}
     for ref in refs:
         counts = count_ngrams(ref, n)
-        for (ngram, count) in counts.items():
+        for ngram, count in counts.items():
             maxcounts[ngram] = max(maxcounts.get(ngram, 0), count)
     return ([len(ref) for ref in refs], maxcounts)
@@ -125,7 +123,7 @@ def cook_test(test, item, n=4):
     result["correct"] = [0] * n
     counts = count_ngrams(test, n)
-    for (ngram, count) in counts.items():
+    for ngram, count in counts.items():
         result["correct"][len(ngram) - 1] += min(refmaxcounts.get(ngram, 0), count)
     return result
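
The two hunks above are the heart of BLEU's modified n-gram precision: `cook_refs` records, per n-gram, the maximum count over all references, and `cook_test` credits candidate n-grams only up to that cap. A worked sketch; the `count_ngrams` helper here is a stand-in assumed to match the script's behavior:

from collections import Counter

def count_ngrams(words, n=4):
    # assumed behavior: counts of all 1..n-grams as tuples
    counts = Counter()
    for k in range(1, n + 1):
        for i in range(len(words) - k + 1):
            counts[tuple(words[i : i + k])] += 1
    return counts

ref = "the cat sat on the mat".split()
test = "the the the cat".split()

refmaxcounts = count_ngrams(ref)
testcounts = count_ngrams(test)

# "the" occurs 3 times in the candidate but only 2 times in the reference,
# so its clipped unigram credit is min(2, 3) = 2.
correct = [0] * 4
for ngram, count in testcounts.items():
    correct[len(ngram) - 1] += min(refmaxcounts.get(ngram, 0), count)

print(correct)  # [3, 1, 0, 0] -- clipped matches per n-gram order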
@@ -222,7 +220,6 @@ def bleuFromMaps(m1, m2):
 def smoothed_bleu_4(references, predictions, **kwargs):
     predictionMap = {}
     goldMap = {}
......
@@ -18,4 +18,4 @@ metric_list:
     aggregation: mean
     higher_is_better: True
 metadata:
-  - version: 0.0
+  version: 1.0
@@ -18,4 +18,4 @@ metric_list:
     aggregation: mean
     higher_is_better: True
 metadata:
-  - version: 0.0
+  version: 1.0
@@ -18,4 +18,4 @@ metric_list:
     aggregation: mean
     higher_is_better: True
 metadata:
-  - version: 0.0
+  version: 1.0
@@ -18,4 +18,4 @@ metric_list:
     aggregation: mean
     higher_is_better: True
 metadata:
-  - version: 0.0
+  version: 1.0
@@ -18,4 +18,4 @@ metric_list:
     aggregation: mean
     higher_is_better: True
 metadata:
-  - version: 0.0
+  version: 1.0
@@ -18,4 +18,4 @@ metric_list:
     aggregation: mean
     higher_is_better: True
 metadata:
-  - version: 2.0
+  version: 3.0
 def doc_to_text(doc):
     inputs = " ".join(doc["code_tokens"]).replace("\n", " ")
     inputs = " ".join(inputs.strip().split())
@@ -7,7 +6,6 @@ def doc_to_text(doc):
 def doc_to_target(doc):
     targets = " ".join(doc["docstring_tokens"]).replace("\n", "")
     targets = " ".join(targets.strip().split())
......
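
Both functions above flatten a token list into a single-spaced string, first mapping newlines away and then collapsing runs of whitespace. With made-up tokens:

doc = {
    "code_tokens": ["def", "f", "(", ")", ":", "\n", "pass"],
    "docstring_tokens": ["Return", "nothing", "."],
}

inputs = " ".join(doc["code_tokens"]).replace("\n", " ")
inputs = " ".join(inputs.strip().split())
print(inputs)  # "def f ( ) : pass"

targets = " ".join(doc["docstring_tokens"]).replace("\n", "")
targets = " ".join(targets.strip().split())
print(targets)  # "Return nothing ."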
@@ -19,4 +19,4 @@ metric_list:
     aggregation: mean
     higher_is_better: true
 metadata:
-  - version: 2.0
+  version: 3.0
@@ -7,7 +7,7 @@ def doc_to_text(doc):
     # Given a passage p, the conversation history {q1, a1, . . . qi−1, ai−1}
     # and a question qi, the task is to predict the answer ai
     doc_text = doc["story"] + "\n\n"
-    for (q, a) in zip_longest(
+    for q, a in zip_longest(
         doc["questions"]["input_text"], doc["answers"]["input_text"][:-1]
     ):  # omit target answer ai
         question = f"Q: {q}\n\n"
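
The loop above reconstructs the conversation history: the answers are sliced with `[:-1]` so the target answer a_i never leaks into the prompt, and `zip_longest` pads the final question with `None`. Illustrated with made-up turns:

from itertools import zip_longest

questions = ["Who?", "Where?", "When?"]
answers = ["Alice", "Paris", "Tuesday"]

# Drop the last answer (the target), then pad with None via zip_longest:
for q, a in zip_longest(questions, answers[:-1]):
    print(f"Q: {q}", f"A: {a}")
# Q: Who?   A: Alice
# Q: Where? A: Paris
# Q: When?  A: None   <- the model must produce this answer itself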
@@ -17,7 +17,6 @@ def doc_to_text(doc):
 def doc_to_target(doc):
     turn_id = len(doc["questions"]["input_text"])
     # Returns unique answers and valid alternatives (Some questions in CoQA have multiple valid answers).
     answers = []
@@ -71,7 +70,6 @@ def compute_scores(gold_list, pred):
 def process_results(doc, results):
     gold_list = doc_to_target(doc)
     pred = results[0].strip().split("\n")[0]
......
@@ -20,4 +20,4 @@ metric_list:
     aggregation: mean
     higher_is_better: false
 metadata:
-  - version: 1.0
+  version: 1.0
@@ -14,4 +14,4 @@ metric_list:
     aggregation: mean
     higher_is_better: true
 metadata:
-  - version: 0.0
+  version: 0.0
@@ -21,7 +21,6 @@ def parse_args():
 if __name__ == "__main__":
     args = parse_args()
     # get filename of base_yaml so we can `"include": ` it in our other YAMLs.
@@ -30,7 +29,6 @@ if __name__ == "__main__":
         base_yaml = yaml.full_load(f)

     for name in tqdm(SUBSETS):
         yaml_dict = {
             "include": base_yaml_name,
             "task": f"csatqa_{args.task_prefix}_{name}"
......
@@ -21,4 +21,4 @@ metric_list:
     aggregation: mean
     higher_is_better: true
 metadata:
-  - version: 2.0
+  version: 3.0
@@ -62,7 +62,6 @@ def parse_answer(answer):
 def process_results(doc, results):
     preds, golds = results, doc["answers"]
     max_em = 0
     max_f1 = 0
......
@@ -12,3 +12,10 @@ metric_list:
   - metric: exact_match
     aggregation: mean
     higher_is_better: true
+filter_list:
+  - name: remove_whitespace
+    filter:
+      - function: remove_whitespace
+      - function: take_first
+metadata:
+  version: 2.0
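
The added `filter_list` post-processes generations before `exact_match` is scored: `remove_whitespace` strips leading whitespace from each candidate and `take_first` keeps a single one. A hand-rolled approximation of that pipeline (not the harness's actual filter classes):

def remove_whitespace(resps):
    # strip leading whitespace from each candidate response
    return [r.lstrip() for r in resps]

def take_first(resps):
    # keep only the first candidate
    return resps[0]

raw = ["  42\nextra text", " 7"]
print(take_first(remove_whitespace(raw)))  # "42\nextra text"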
-include: fld.yaml
+include: fld_default.yaml
 task: fld_star
 dataset_name: star
@@ -13,4 +13,4 @@ doc_to_decontamination_query: sentence
 metric_list:
   - metric: mcc
 metadata:
-  - version: 1.0
+  version: 1.0