Commit 0d1ef037 authored by lintangsutawika's avatar lintangsutawika

Resolved merge conflict

parents aa44be3f ada4a31d
@@ -16,4 +16,4 @@ metric_list:
     aggregation: mean
     higher_is_better: true
 metadata:
-  - version: 0.0
+  version: 0.0
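Editorial note on the hunk above: the `metadata` block moves from a YAML list item to a plain mapping, and the two forms parse differently. A minimal PyYAML sketch (not part of the commit) showing why the mapping form lets `metadata["version"]` work directly:

import yaml

# Hypothetical illustration: the old form yields a list of one-key dicts,
# the new form yields a plain mapping.
old = yaml.safe_load("metadata:\n  - version: 0.0\n")
new = yaml.safe_load("metadata:\n  version: 0.0\n")
print(old["metadata"])             # [{'version': 0.0}]
print(new["metadata"]["version"])  # 0.0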
@@ -90,7 +90,6 @@ def parse_args():
 if __name__ == "__main__":
     args = parse_args()
     # get filename of base_yaml so we can `"include": ` it in our other YAMLs.
@@ -108,7 +107,9 @@ if __name__ == "__main__":
         if args.cot_prompt_path is not None:
             description = cot_file[subject_eng]
         else:
-            description = f"以下是关于{subject_zh}的单项选择题,请直接给出正确答案的选项。\n\n"
+            description = (
+                f"以下是关于{subject_zh}的单项选择题,请直接给出正确答案的选项。\n\n"
+            )
         yaml_dict = {
             "include": base_yaml_name,
......
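The prompt literal above translates to roughly: "The following are multiple-choice questions about {subject}. Please directly give the option of the correct answer." For context, a hedged sketch of the per-subject YAML this generator loop emits; the file and task names here are assumptions for illustration, not taken from the commit:

import yaml

# Hypothetical example of one generated config: it includes the shared base
# YAML and overrides only the subject-specific fields.
description = "以下是关于计算机网络的单项选择题,请直接给出正确答案的选项。\n\n"
yaml_dict = {
    "include": "_default_ceval_yaml",        # assumed base_yaml_name
    "task": "ceval-valid_computer_network",  # assumed task naming scheme
    "dataset_name": "computer_network",
    "description": description,
}
with open("ceval-valid_computer_network.yaml", "w", encoding="utf-8") as f:
    yaml.dump(yaml_dict, f, allow_unicode=True, default_flow_style=False)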
#!/usr/bin/python
import os
import re
import sys
import math
import subprocess
import xml.sax.saxutils
from typing import List, Pattern, Tuple, Union, Dict, Any, Optional
@@ -65,14 +63,14 @@ def normalize(s):
     if type(s) is not str:
         s = " ".join(s)
     # language-independent part:
-    for (pattern, replace) in normalize1:
+    for pattern, replace in normalize1:
         s = re.sub(pattern, replace, s)
     s = xml.sax.saxutils.unescape(s, {"&quot;": '"'})
     # language-dependent part (assuming Western languages):
     s = " %s " % s
     if not preserve_case:
         s = s.lower()  # this might not be identical to the original
-    for (pattern, replace) in normalize2:
+    for pattern, replace in normalize2:
         s = re.sub(pattern, replace, s)
     return s.split()
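A side note on the `unescape` call above: the extra entity map restores literal double quotes on top of the default `&amp;`/`&lt;`/`&gt;` handling. A small standalone check (not part of the diff):

import xml.sax.saxutils

# The second argument adds &quot; to the entities unescape already decodes.
s = xml.sax.saxutils.unescape("a &quot;quoted&quot; token &amp; more", {"&quot;": '"'})
print(s)  # a "quoted" token & more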
@@ -95,7 +93,7 @@ def cook_refs(refs, n=4):
     maxcounts: Dict[Tuple[str], int] = {}
     for ref in refs:
         counts = count_ngrams(ref, n)
-        for (ngram, count) in counts.items():
+        for ngram, count in counts.items():
             maxcounts[ngram] = max(maxcounts.get(ngram, 0), count)
     return ([len(ref) for ref in refs], maxcounts)
@@ -125,7 +123,7 @@ def cook_test(test, item, n=4):
     result["correct"] = [0] * n
     counts = count_ngrams(test, n)
-    for (ngram, count) in counts.items():
+    for ngram, count in counts.items():
         result["correct"][len(ngram) - 1] += min(refmaxcounts.get(ngram, 0), count)
     return result
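`count_ngrams` itself is not shown in this diff. A minimal sketch of the clipped n-gram counting that `cook_refs`/`cook_test` rely on, assuming it maps n-gram tuples to occurrence counts:

from collections import Counter

def count_ngrams(words, n=4):
    # Count all n-grams of order 1..n, keyed by the n-gram tuple.
    counts = Counter()
    for k in range(1, n + 1):
        for i in range(len(words) - k + 1):
            counts[tuple(words[i:i + k])] += 1
    return counts

ref = "the cat sat on the mat".split()
hyp = "the the the cat".split()
refcounts = count_ngrams(ref)
hypcounts = count_ngrams(hyp)
# BLEU clips each hypothesis n-gram count at its maximum reference count.
clipped = sum(
    min(count, refcounts.get(ngram, 0))
    for ngram, count in hypcounts.items()
    if len(ngram) == 1
)
print(clipped)  # 3 -- "the" clipped from 3 to 2, plus "cat"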
@@ -222,7 +220,6 @@ def bleuFromMaps(m1, m2):
 def smoothed_bleu_4(references, predictions, **kwargs):
     predictionMap = {}
     goldMap = {}
......
@@ -18,4 +18,4 @@ metric_list:
     aggregation: mean
     higher_is_better: True
 metadata:
-  - version: 0.0
+  version: 1.0
@@ -18,4 +18,4 @@ metric_list:
     aggregation: mean
     higher_is_better: True
 metadata:
-  - version: 0.0
+  version: 1.0
@@ -18,4 +18,4 @@ metric_list:
     aggregation: mean
     higher_is_better: True
 metadata:
-  - version: 0.0
+  version: 1.0
@@ -18,4 +18,4 @@ metric_list:
     aggregation: mean
     higher_is_better: True
 metadata:
-  - version: 0.0
+  version: 1.0
@@ -18,4 +18,4 @@ metric_list:
     aggregation: mean
     higher_is_better: True
 metadata:
-  - version: 0.0
+  version: 1.0
@@ -18,4 +18,4 @@ metric_list:
     aggregation: mean
     higher_is_better: True
 metadata:
-  - version: 2.0
+  version: 3.0
 def doc_to_text(doc):
     inputs = " ".join(doc["code_tokens"]).replace("\n", " ")
     inputs = " ".join(inputs.strip().split())
@@ -7,7 +6,6 @@ def doc_to_text(doc):
 def doc_to_target(doc):
     targets = " ".join(doc["docstring_tokens"]).replace("\n", "")
     targets = " ".join(targets.strip().split())
......
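A toy run of the whitespace normalization above (standalone, not part of the diff): newlines are dropped and runs of whitespace collapse to single spaces.

doc = {"code_tokens": ["def", "f", "(", ")", ":", "\n", "return", "1"]}
inputs = " ".join(doc["code_tokens"]).replace("\n", " ")
inputs = " ".join(inputs.strip().split())
print(inputs)  # def f ( ) : return 1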
@@ -19,4 +19,4 @@ metric_list:
     aggregation: mean
     higher_is_better: true
 metadata:
-  - version: 2.0
+  version: 3.0
@@ -7,7 +7,7 @@ def doc_to_text(doc):
     # Given a passage p, the conversation history {q1, a1, ..., qi−1, ai−1}
     # and a question qi, the task is to predict the answer ai
     doc_text = doc["story"] + "\n\n"
-    for (q, a) in zip_longest(
+    for q, a in zip_longest(
         doc["questions"]["input_text"], doc["answers"]["input_text"][:-1]
     ):  # omit target answer ai
         question = f"Q: {q}\n\n"
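The `zip_longest` over the questions and all-but-the-last answers pairs the final question with `None`, which is how the target turn is left unanswered in the prompt. A toy example (not part of the diff):

from itertools import zip_longest

questions = ["Q1", "Q2", "Q3"]
answers = ["A1", "A2", "A3"]
for q, a in zip_longest(questions, answers[:-1]):
    print(q, a)
# Q1 A1
# Q2 A2
# Q3 None  <- the turn the model must answer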
@@ -17,7 +17,6 @@ def doc_to_text(doc):
 def doc_to_target(doc):
     turn_id = len(doc["questions"]["input_text"])
     # Returns unique answers and valid alternatives (some questions in CoQA have multiple valid answers).
     answers = []
@@ -71,7 +70,6 @@ def compute_scores(gold_list, pred):
 def process_results(doc, results):
     gold_list = doc_to_target(doc)
     pred = results[0].strip().split("\n")[0]
......
@@ -20,4 +20,4 @@ metric_list:
     aggregation: mean
     higher_is_better: false
 metadata:
-  - version: 1.0
+  version: 1.0
@@ -14,4 +14,4 @@ metric_list:
     aggregation: mean
     higher_is_better: true
 metadata:
-  - version: 0.0
+  version: 0.0
@@ -21,7 +21,6 @@ def parse_args():
 if __name__ == "__main__":
     args = parse_args()
     # get filename of base_yaml so we can `"include": ` it in our other YAMLs.
@@ -30,7 +29,6 @@ if __name__ == "__main__":
         base_yaml = yaml.full_load(f)
     for name in tqdm(SUBSETS):
         yaml_dict = {
             "include": base_yaml_name,
             "task": f"csatqa_{args.task_prefix}_{name}"
......
@@ -21,4 +21,4 @@ metric_list:
     aggregation: mean
     higher_is_better: true
 metadata:
-  - version: 2.0
+  version: 3.0
@@ -62,7 +62,6 @@ def parse_answer(answer):
 def process_results(doc, results):
     preds, golds = results, doc["answers"]
     max_em = 0
     max_f1 = 0
......
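`process_results` above scores a prediction against every gold answer and keeps the best match. A hedged sketch of that max-over-golds pattern with simplified, hypothetical `exact_match`/`token_f1` helpers; the real metrics in the harness normalize text more carefully:

def exact_match(pred, gold):
    # Hypothetical helper: case-insensitive exact match.
    return float(pred.strip().lower() == gold.strip().lower())

def token_f1(pred, gold):
    # Hypothetical helper: token-level F1, SQuAD-style.
    p, g = pred.lower().split(), gold.lower().split()
    common = sum(min(p.count(t), g.count(t)) for t in set(p))
    if common == 0:
        return 0.0
    precision, recall = common / len(p), common / len(g)
    return 2 * precision * recall / (precision + recall)

golds = ["Barack Obama", "Obama"]
pred = "obama"
print(max(exact_match(pred, g) for g in golds))  # 1.0
print(max(token_f1(pred, g) for g in golds))     # 1.0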
@@ -12,3 +12,10 @@ metric_list:
   - metric: exact_match
     aggregation: mean
     higher_is_better: true
+filter_list:
+  - name: remove_whitespace
+    filter:
+      - function: remove_whitespace
+      - function: take_first
+metadata:
+  version: 2.0
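The `filter_list` added above declares a named pipeline applied to raw model outputs before scoring: strip leading whitespace, then keep one candidate. A toy sketch of that composition with simplified stand-ins, not the harness's actual filter implementations:

def remove_whitespace(responses):
    # Simplified stand-in: strip leading whitespace from each candidate.
    return [r.lstrip() for r in responses]

def take_first(responses):
    # Simplified stand-in: keep only the first candidate per document.
    return responses[:1]

responses = ["  42\nbecause...", " 7"]
for step in (remove_whitespace, take_first):
    responses = step(responses)
print(responses)  # ['42\nbecause...']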
-include: fld.yaml
+include: fld_default.yaml
 task: fld_star
 dataset_name: star
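The `include:` key now points fld_star at the renamed base config. A hedged sketch of how such includes can be resolved, with the child's keys overriding the base; the harness's actual loader may differ:

import yaml

def load_with_include(path):
    # Load a task config; if it names a base file via `include`,
    # load that first and let the child's keys override it.
    with open(path) as f:
        cfg = yaml.safe_load(f)
    base_name = cfg.pop("include", None)
    if base_name:
        merged = load_with_include(base_name)
        merged.update(cfg)
        cfg = merged
    return cfg

# cfg = load_with_include("fld_star.yaml")  # would pull in fld_default.yaml first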
@@ -13,4 +13,4 @@ doc_to_decontamination_query: sentence
 metric_list:
   - metric: mcc
 metadata:
-  - version: 1.0
+  version: 1.0