Unverified Commit 861942ab authored by Leymore, committed by GitHub

[Feature] Add lawbench (#460)

* add lawbench

* update requirements

* update

parent fbf5089c
"""
task: multiple choice classification
metric: F1 score
Marriage-related text classification (婚姻文本分类)
"""
def compute_wbfl(data_dict):
"""
A reference (R) contains a list of options, each option is from the option_list.
We will extract the options appearing in the prediction and convert them into a set (P).
We compute the F1 score between the prediction (P) and the reference (R).
"""
score_list, abstentions = [], 0
option_list = ["婚后有子女", "限制行为能力子女抚养", "有夫妻共同财产", "支付抚养费", "不动产分割", "婚后分局",
"二次起诉离婚", "按月给付抚养费", "准予离婚", "有夫妻共同债务", "婚前个人财产", "法定离婚", "不履行家庭义务",
"存在非婚生子", "适当帮助", "不履行离婚协议", "损害赔偿", "感情不和分居满二年", "子女随非抚养权人生活", "婚后个人财产"]
for example in data_dict:
question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"]
assert answer.startswith("类别:") and answer.endswith("。"), f"answer: {answer}, question: {question}"
gt_list = (answer[3:-1].split("、"))
for gt in gt_list:
assert gt in option_list, f"gt: {gt}, question: {question}"
gt_set = set(gt_list)
prediction_list = []
for option in option_list:
if option in prediction:
prediction_list.append(option)
if len(prediction_list) == 0:
abstentions += 1
predict_set = set(prediction_list)
precision = len(gt_set.intersection(predict_set)) / len(predict_set) if len(predict_set) != 0 else 0
recall = len(gt_set.intersection(predict_set)) / len(gt_set) if len(gt_set) != 0 else 0
f1_score = 2 * precision * recall / (precision + recall) if (precision + recall) != 0 else 0
score_list.append(f1_score)
# average the per-example F1 scores
final_f1_score = sum(score_list) / len(score_list)
return {'score': final_f1_score, 'abstention_rate': abstentions / len(data_dict)}
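# A minimal usage sketch (illustrative only; the record below is a hypothetical example,
# not taken from the LawBench data):
#
#     data_dict = [{
#         "origin_prompt": "...",
#         "prediction": "类别:婚后有子女、准予离婚",
#         "refr": "类别:婚后有子女、准予离婚。",
#     }]
#     compute_wbfl(data_dict)  # -> {'score': 1.0, 'abstention_rate': 0.0}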
import re
import os
import subprocess
"""
Task: legal document grammar correction
Metric: F0.5 score
Legal document proofreading (文书校对)
"""
def compute_wsjd(data_dict):
origins, references, predictions = [], [], []
for example in data_dict:
question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"]
if isinstance(question, list):
question = question[0]['prompt']
start = question.index('句子:\n') + 4
origins.append(re.sub(r'\n|\t', '', question[start:].split('\n')[0]))
# truncate predictions >5 tokens longer than the reference
prediction = re.sub(r'\n|\t', '', prediction)
if len(prediction) - len(answer) > 5:
prediction = prediction[:len(answer) + 5]
if len(prediction) == 0:
prediction = "无内容"
predictions.append(prediction)
references.append(re.sub(r'\n|\t', '', answer))
#generate input files for ChERRANT
preds = [f'{i} \t {origin} \t {prediction} \n' for i, (origin, prediction) in enumerate(zip(origins, predictions))]
golds = [f'{i} \t {origin} \t {reference} \n' for i, (origin, reference) in enumerate(zip(origins, references))]
now_path = os.path.abspath(os.getcwd())
utils_path = os.path.abspath(os.path.join(__file__, '..', '..', 'utils'))
uid = os.getuid()
os.chdir(utils_path)
with open(f'/tmp/tmp_pred_{uid}.para', 'w') as f:
f.writelines(preds)
with open(f'/tmp/tmp_gold_{uid}.para', 'w') as f:
f.writelines(golds)
os.environ['KMP_DUPLICATE_LIB_OK']='True'
os.system(f'python3 parallel_to_m2.py -f /tmp/tmp_pred_{uid}.para -o /tmp/tmp_pred_{uid}.para.m2 -g char')
os.system(f'python3 parallel_to_m2.py -f /tmp/tmp_gold_{uid}.para -o /tmp/tmp_gold_{uid}.para.m2 -g char')
output = subprocess.check_output(f"python3 compare_m2_for_evaluation.py -hyp /tmp/tmp_pred_{uid}.para.m2 -ref /tmp/tmp_gold_{uid}.para.m2", shell = True)
score = float(output.decode().split('\t')[-1].split('\n')[0])
#remove prediction files
os.remove(f'/tmp/tmp_pred_{uid}.para')
os.remove(f'/tmp/tmp_gold_{uid}.para')
os.remove(f'/tmp/tmp_pred_{uid}.para.m2')
os.remove(f'/tmp/tmp_gold_{uid}.para.m2')
os.chdir(now_path)
return {"score": score}
from ..utils.comprehension_scores import compute_ie_f1
"""
task: information extraction
metric: F1 score
Information extraction (信息抽取)
"""
def compute_xxcq(data_dict):
references, predictions = [], []
for example in data_dict:
question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"]
predictions.append(prediction)
references.append(answer)
return compute_ie_f1(predictions, references, {"犯罪嫌疑人", "受害人", "被盗货币", "物品价值", "盗窃获利",
"被盗物品", "作案工具", "时间", "地点", "组织机构"})
from ..utils.comprehension_scores import compute_rc_f1
"""
Task: machine reading comprehension
Metric: F1 score
Legal reading comprehension (法律阅读理解)
"""
def compute_ydlj(data_dict):
references, predictions = [], []
for example in data_dict:
question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"]
answer = answer.replace("回答:", "")
predictions.append(prediction)
references.append(answer)
f1_score = compute_rc_f1(predictions, references)
return f1_score
from ..utils.function_utils import compute_rouge
# Public opinion summarization (舆情摘要)
def compute_yqzy(data_dict):
"""
Compute the ROUGE-L score between the prediction and the reference
"""
references, predictions = [], []
for example in data_dict:
question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"]
predictions.append(prediction)
references.append(answer)
# average the ROUGE-L scores over the dataset
rouge_scores = compute_rouge(predictions, references)
rouge_ls = [score["rouge-l"]["f"] for score in rouge_scores]
average_rouge_l = sum(rouge_ls) / len(rouge_ls)
return {"score": average_rouge_l}
from ..utils.function_utils import multi_choice_judge
"""
task: multiple choice classification
metric: accuracy
Consultation classification (咨询分类)
"""
def compute_zxfl(data_dict):
"""
A reference (R) contains a list of options, each option is from the option_list.
We will extract the options appearing in the prediction and convert them into a set (P).
We compute the accuracy between the prediction (P) and the reference (R).
"""
score_list, abstentions = [], 0
option_list = ['婚姻家庭', '劳动纠纷', '交通事故', '债权债务', '刑事辩护', '合同纠纷', '房产纠纷', '侵权', '公司法', '医疗纠纷', '拆迁安置', '行政诉讼', '建设工程', '知识产权', '综合咨询', '人身损害', '涉外法律', '海事海商', '消费权益', '抵押担保']
for example in data_dict:
question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"]
judge = multi_choice_judge(prediction, option_list, answer)
score_list.append(judge["score"])
abstentions += judge["abstention"]
# compute the accuracy of score_list
final_accuracy_score = sum(score_list) / len(score_list)
return {'score': final_accuracy_score, 'abstention_rate': abstentions / len(data_dict)}
import json
import os
from datasets import Dataset
from opencompass.openicl.icl_evaluator import BaseEvaluator
from opencompass.registry import ICL_EVALUATORS, LOAD_DATASET
from ..base import BaseDataset
from .evaluation_functions import (cjft, flzx, ftcs, jdzy, jec_ac, jec_kd,
jetq, lblj, ljp_accusation, ljp_article,
ljp_imprison, sjjc, wbfl, wsjd, xxcq, ydlj,
yqzy, zxfl)
@LOAD_DATASET.register_module()
class LawBenchDataset(BaseDataset):
@staticmethod
def load(path: str, index: str) -> Dataset:
path = os.path.join(path, index + '.json')
with open(path, 'r') as f:
data = json.load(f)
return Dataset.from_list(data)
funct_dict = {
'1-1': ftcs.compute_ftcs,
'1-2': jec_kd.compute_jec_kd,
'2-1': wsjd.compute_wsjd,
'2-2': jdzy.compute_jdzy,
'2-3': wbfl.compute_wbfl,
'2-4': zxfl.compute_zxfl,
'2-5': ydlj.compute_ydlj,
'2-6': xxcq.compute_xxcq,
'2-7': yqzy.compute_yqzy,
'2-8': lblj.compute_lblj,
'2-9': sjjc.compute_sjjc,
'2-10': sjjc.compute_cfcy,
'3-1': ljp_article.compute_ljp_article,
'3-2': cjft.compute_cjft,
'3-3': ljp_accusation.compute_ljp_accusation,
'3-4': ljp_imprison.compute_ljp_imprison,
'3-5': ljp_imprison.compute_ljp_imprison,
'3-6': jec_ac.compute_jec_ac,
'3-7': jetq.compute_jetq,
'3-8': flzx.compute_flzx,
}
class LawBenchEvaluator(BaseEvaluator):
def __init__(self, index) -> None:
super().__init__()
self.index = index
def score(self, predictions, references, origin_prompt):
if len(predictions) != len(references):
return {
'error': 'predictions and references have different '
'length'
}
data_dict = [{
'origin_prompt': origin_prompt[i],
'prediction': predictions[i],
'refr': references[i],
} for i in range(len(predictions))]
scores = funct_dict[self.index](data_dict)
scores = {k: v * 100 for k, v in scores.items()}
return scores
for index in funct_dict:
# fix classic closure problem
def _register(index):
ICL_EVALUATORS.register_module(
name='LawBenchEvaluator_' + index.replace('-', '_'),
module=lambda *args, **kwargs: LawBenchEvaluator(
index=index, *args, **kwargs))
_register(index)
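# A hedged usage sketch of the registered evaluators (the strings below are hypothetical
# placeholders, not LawBench data). Index '2-4' routes to zxfl.compute_zxfl, so a
# prediction containing exactly the gold option scores 100 after the x100 rescaling:
#
#     evaluator = LawBenchEvaluator(index='2-4')
#     evaluator.score(predictions=['这应当属于婚姻家庭类咨询'],
#                     references=['婚姻家庭'],
#                     origin_prompt=['请判断下列法律咨询的类别:...'])
#     # -> {'score': 100.0, 'abstention_rate': 0.0}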
### Copy from https://github.com/iqiyi/FASPell ###
"""
Requirements:
- java (required only if tree edit distance is used)
- numpy
"""
import numpy as np
from subprocess import Popen, PIPE, STDOUT
import os
import argparse
IDCS = {'\u2ff0': 2, # 12 ideographic description characters and their capacity of son nodes
'\u2ff1': 2,
'\u2ff2': 3,
'\u2ff3': 3,
'\u2ff4': 2,
'\u2ff5': 2,
'\u2ff6': 2,
'\u2ff7': 2,
'\u2ff8': 2,
'\u2ff9': 2,
'\u2ffa': 2,
'\u2ffb': 2, }
PINYIN = {'ā': ['a', 1], 'á': ['a', 2], 'ǎ': ['a', 3], 'à': ['a', 4],
'ē': ['e', 1], 'é': ['e', 2], 'ě': ['e', 3], 'è': ['e', 4],
'ī': ['i', 1], 'í': ['i', 2], 'ǐ': ['i', 3], 'ì': ['i', 4],
'ō': ['o', 1], 'ó': ['o', 2], 'ǒ': ['o', 3], 'ò': ['o', 4],
'ū': ['u', 1], 'ú': ['u', 2], 'ǔ': ['u', 3], 'ù': ['u', 4],
'ǖ': ['ü', 1], 'ǘ': ['ü', 2], 'ǚ': ['ü', 3], 'ǜ': ['ü', 4],
'': ['m', 2], 'ń': ['n', 2], 'ň': ['n', 3], 'ǹ': ['n', 4],
}
# APTED_JAR_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'apted.jar')
APTED_JAR_PATH = 'apted.jar'
def tree_edit_distance(tree_a, tree_b):
"""
We use APTED algorithm proposed by M. Pawlik and N. Augsten
github link: https://github.com/DatabaseGroup/apted
"""
p = Popen(['java', '-jar', APTED_JAR_PATH, '-t', tree_a, tree_b], stdout=PIPE, stderr=STDOUT)
res = [line for line in p.stdout]
res = res[0]
res = res.strip()
res = float(res)
return res
def edit_distance(string_a, string_b, name='Levenshtein'):
"""
>>> edit_distance('abcde', 'avbcude')
2
>>> edit_distance(['至', '刂'], ['亻', '至', '刂'])
1
>>> edit_distance('fang', 'qwe')
4
>>> edit_distance('fang', 'hen')
3
"""
size_x = len(string_a) + 1
size_y = len(string_b) + 1
matrix = np.zeros((size_x, size_y), dtype=int)
for x in range(size_x):
matrix[x, 0] = x
for y in range(size_y):
matrix[0, y] = y
for x in range(1, size_x):
for y in range(1, size_y):
if string_a[x - 1] == string_b[y - 1]:
matrix[x, y] = min(
matrix[x - 1, y] + 1,
matrix[x - 1, y - 1],
matrix[x, y - 1] + 1
)
else:
if name == 'Levenshtein':
matrix[x, y] = min(
matrix[x - 1, y] + 1,
matrix[x - 1, y - 1] + 1,
matrix[x, y - 1] + 1
)
else: # Canonical
matrix[x, y] = min(
matrix[x - 1, y] + 1,
matrix[x - 1, y - 1] + 2,
matrix[x, y - 1] + 1
)
return matrix[size_x - 1, size_y - 1]
class CharFuncs(object):
def __init__(self, char_meta_fname):
self.data = self.load_char_meta(char_meta_fname)
self.char_dict = dict([(c, 0) for c in self.data])
self.safe = {'\u2ff0': 'A',
# to eliminate the bug that, in Windows CMD, char ⿻ and ⿵ are encoded to be the same.
'\u2ff1': 'B',
'\u2ff2': 'C',
'\u2ff3': 'D',
'\u2ff4': 'E',
'\u2ff5': 'F',
'\u2ff6': 'G',
'\u2ff7': 'H',
'\u2ff8': 'I',
'\u2ff9': 'J',
'\u2ffa': 'L',
'\u2ffb': 'M', }
@staticmethod
def load_char_meta(fname):
data = {}
f = open(fname, 'r', encoding='utf-8')
for line in f:
items = line.strip().split('\t')
code_point = items[0]
char = items[1]
pronunciation = items[2]
decompositions = items[3:]
assert char not in data
data[char] = {"code_point": code_point, "pronunciation": pronunciation, "decompositions": decompositions}
return data
def shape_distance(self, char1, char2, safe=True, as_tree=False):
"""
>>> c = CharFuncs('data/char_meta.txt')
>>> c.shape_distance('田', '由')
1
>>> c.shape_distance('牛', '午')
1
"""
assert char1 in self.data
assert char2 in self.data
def safe_encode(decomp):
tree = ''
for c in string_to_tree(decomp):
if c not in self.safe:
tree += c
else:
tree += self.safe[c]
return tree
def safe_encode_string(decomp):
tree = ''
for c in decomp:
if c not in self.safe:
tree += c
else:
tree += self.safe[c]
return tree
decomps_1 = self.data[char1]["decompositions"]
decomps_2 = self.data[char2]["decompositions"]
distance = 1e5
if as_tree:
for decomp1 in decomps_1:
for decomp2 in decomps_2:
if not safe:
ted = tree_edit_distance(string_to_tree(decomp1), string_to_tree(decomp2))
else:
ted = tree_edit_distance(safe_encode(decomp1), safe_encode(decomp2))
distance = min(distance, ted)
else:
for decomp1 in decomps_1:
for decomp2 in decomps_2:
if not safe:
ed = edit_distance(decomp1, decomp2)
else:
ed = edit_distance(safe_encode_string(decomp1), safe_encode_string(decomp2))
distance = min(distance, ed)
return distance
def pronunciation_distance(self, char1, char2):
"""
>>> c = CharFuncs('data/char_meta.txt')
>>> c.pronunciation_distance('田', '由')
3.4
>>> c.pronunciation_distance('牛', '午')
2.6
"""
assert char1 in self.data
assert char2 in self.data
pronunciations1 = self.data[char1]["pronunciation"]
pronunciations2 = self.data[char2]["pronunciation"]
if pronunciations1[0] == 'null' or pronunciations2 == 'null':
return 0.0
else:
pronunciations1 = pronunciations1.split(';') # separate by lan
pronunciations2 = pronunciations2.split(';') # separate by lan
distance = 0.0
count = 0
for pron_lan1, pron_lan2 in zip(pronunciations1, pronunciations2):
if (pron_lan1 == 'null') or (pron_lan2 == 'null'):
pass
else:
distance_lan = 1e5
for p1 in pron_lan1.split(','):
for p2 in pron_lan2.split(','):
distance_lan = min(distance_lan, edit_distance(p1, p2))
distance += distance_lan
count += 1
return distance / count
@staticmethod
def load_dict(fname):
data = {}
f = open(fname, 'r', encoding='utf-8')
for line in f:
char, freq = line.strip().split('\t')
assert char not in data
data[char] = freq
return data
def similarity(self, char1, char2, weights=(0.8, 0.2, 0.0), as_tree=False):
"""
this function returns weighted similarity. When used in FASPell, each weight can only be 0 or 1.
"""
# assert char1 in self.char_dict
# assert char2 in self.char_dict
shape_w, sound_w, freq_w = weights
if char1 in self.char_dict and char2 in self.char_dict:
shape_sim = self.shape_similarity(char1, char2, as_tree=as_tree)
sound_sim = self.pronunciation_similarity(char1, char2)
freq_sim = 1.0 - self.char_dict[char2] / len(self.char_dict)
return shape_sim * shape_w + sound_sim * sound_w + freq_sim * freq_w
else:
return 0.0
def shape_similarity(self, char1, char2, safe=True, as_tree=False):
"""
>>> c = CharFuncs('data/char_meta.txt')
>>> c.shape_similarity('牛', '午')
0.8571428571428572
>>> c.shape_similarity('田', '由')
0.8888888888888888
"""
assert char1 in self.data
assert char2 in self.data
def safe_encode(decomp):
tree = ''
for c in string_to_tree(decomp):
if c not in self.safe:
tree += c
else:
tree += self.safe[c]
return tree
def safe_encode_string(decomp):
tree = ''
for c in decomp:
if c not in self.safe:
tree += c
else:
tree += self.safe[c]
return tree
decomps_1 = self.data[char1]["decompositions"]
decomps_2 = self.data[char2]["decompositions"]
similarity = 0.0
if as_tree:
for decomp1 in decomps_1:
for decomp2 in decomps_2:
if not safe:
ted = tree_edit_distance(string_to_tree(decomp1), string_to_tree(decomp2))
else:
ted = tree_edit_distance(safe_encode(decomp1), safe_encode(decomp2))
normalized_ted = 2 * ted / (len(decomp1) + len(decomp2) + ted)
similarity = max(similarity, 1 - normalized_ted)
else:
for decomp1 in decomps_1:
for decomp2 in decomps_2:
if not safe:
ed = edit_distance(decomp1, decomp2)
else:
ed = edit_distance(safe_encode_string(decomp1), safe_encode_string(decomp2))
normalized_ed = ed / max(len(decomp1), len(decomp2))
similarity = max(similarity, 1 - normalized_ed)
return similarity
def pronunciation_similarity(self, char1, char2):
"""
>>> c = CharFuncs('data/char_meta.txt')
>>> c.pronunciation_similarity('牛', '午')
0.27999999999999997
>>> c.pronunciation_similarity('由', '田')
0.09
"""
assert char1 in self.data
assert char2 in self.data
pronunciations1 = self.data[char1]["pronunciation"]
pronunciations2 = self.data[char2]["pronunciation"]
if pronunciations1[0] == 'null' or pronunciations2 == 'null':
return 0.0
else:
pronunciations1 = pronunciations1.split(';') # separate by lan
pronunciations2 = pronunciations2.split(';') # separate by lan
similarity = 0.0
count = 0
for pron_lan1, pron_lan2 in zip(pronunciations1, pronunciations2):
if (pron_lan1 == 'null') or (pron_lan2 == 'null'):
pass
else:
similarity_lan = 0.0
for p1 in pron_lan1.split(','):
for p2 in pron_lan2.split(','):
tmp_sim = 1 - edit_distance(p1, p2) / max(len(p1), len(p2))
similarity_lan = max(similarity_lan, tmp_sim)
similarity += similarity_lan
count += 1
return similarity / count if count else 0
def string_to_tree(string):
"""
This function converts an IDS (ideographic description sequence) string into a string that can be used as tree input to APTED.
Any Error raised by this function implies that the input string is invalid.
>>> string_to_tree('⿱⿱⿰丿㇏⿰丿㇏⿱⿰丿㇏⿰丿㇏') # 炎
'{⿱{⿱{⿰{丿}{㇏}}{⿰{丿}{㇏}}}{⿱{⿰{丿}{㇏}}{⿰{丿}{㇏}}}}'
>>> string_to_tree('⿱⿰丿㇏⿱一⿱⿻一丨一') # 全
'{⿱{⿰{丿}{㇏}}{⿱{一}{⿱{⿻{一}{丨}}{一}}}}'
>>> string_to_tree('⿱⿰丿㇏⿻⿱一⿱⿻一丨一丷') # 金
'{⿱{⿰{丿}{㇏}}{⿻{⿱{一}{⿱{⿻{一}{丨}}{一}}}{丷}}}'
>>> string_to_tree('⿻⿻⿻一丨一⿴⿱⿰丨𠃌一一') # 車
'{⿻{⿻{⿻{一}{丨}}{一}}{⿴{⿱{⿰{丨}{𠃌}}{一}}{一}}}'
>>> string_to_tree('⿻⿻⿻一丨⿰丿㇏⿴⿱⿰丨𠃌一一') # 東
'{⿻{⿻{⿻{一}{丨}}{⿰{丿}{㇏}}}{⿴{⿱{⿰{丨}{𠃌}}{一}}{一}}}'
>>> string_to_tree('丿') # 丿
'{丿}'
>>> string_to_tree('⿻') # ⿻
'{⿻}'
"""
if string[0] in IDCS and len(string) != 1:
bracket_stack = []
tree = []
def add_brackets(num):
if num == 2:
bracket_stack.extend(['}', '{', '}'])
else:
bracket_stack.extend(['}', '{', '}', '{', '}'])
tree.append('{')
global_just_put = '{'
for c in string:
tree.append(c)
if c in IDCS:
assert global_just_put != '}'
add_brackets(IDCS[c])
global_just_put = '{'
else:
just_put = ''
while just_put != '{' and bracket_stack:
just_put = bracket_stack.pop(-1)
tree.append(just_put)
global_just_put = just_put
res = ''.join(tree)
assert res[-1] == '}'
else:
assert len(string) == 1 or string == 'null'
res = string[0]
return '{' + res + '}'
def pinyin_map(standard_pinyin):
"""
>>> pinyin_map('xuě')
'xue3'
>>> pinyin_map('xue')
'xue'
>>> pinyin_map('lǜ')
'lü4'
>>> pinyin_map('fá')
'fa2'
"""
tone = ''
pinyin = ''
assert ' ' not in standard_pinyin
for c in standard_pinyin:
if c in PINYIN:
pinyin += PINYIN[c][0]
assert tone == ''
tone = str(PINYIN[c][1])
else:
pinyin += c
pinyin += tone
return pinyin
def parse_args():
usage = '\n1. You can compute character similarity by:\n' \
'python char_sim.py 午 牛 年 千\n' \
'\n' \
'2. You can use ted in computing character similarity by:\n' \
'python char_sim.py 午 牛 年 千 -t\n' \
'\n'
parser = argparse.ArgumentParser(
description='A script to compute Chinese character (Kanji) similarity', usage=usage)
parser.add_argument('multiargs', nargs='*', type=str, default=None,
help='Chinese characters in question')
parser.add_argument('--ted', '-t', action="store_true", default=False,
help='True=to use tree edit distance (TED); '
'False=to use string edit distance')
args = parser.parse_args()
return args
if __name__ == '__main__':
args = parse_args()
c = CharFuncs('data/char_meta.txt')
if not args.ted:
for i, c1 in enumerate(args.multiargs):
for c2 in args.multiargs[i:]:
if c1 != c2:
print(f'For character pair ({c1}, {c2}):')
print(f' v-sim = {c.shape_similarity(c1, c2)}')
print(f' p-sim = {c.pronunciation_similarity(c1, c2)}\n')
else:
for i, c1 in enumerate(args.multiargs):
for c2 in args.multiargs[i:]:
if c1 != c2:
print(f'For character pair ({c1}, {c2}):')
print(f' v-sim = {c.shape_similarity(c1, c2, as_tree=True)}')
print(f' p-sim = {c.pronunciation_similarity(c1, c2)}\n')
import argparse
from collections import Counter
def main():
# Parse command line args
args = parse_args()
# Open hypothesis and reference m2 files and split into chunks
hyp_m2 = open(args.hyp).read().strip().split("\n\n")[args.start:args.end] if args.start is not None and args.end is not None else open(args.hyp).read().strip().split("\n\n")
ref_m2 = open(args.ref).read().strip().split("\n\n")[args.start:args.end] if args.start is not None and args.end is not None else open(args.ref).read().strip().split("\n\n")
# Make sure they have the same number of sentences
assert len(hyp_m2) == len(ref_m2), print(len(hyp_m2), len(ref_m2))
# Store global corpus level best counts here
best_dict = Counter({"tp":0, "fp":0, "fn":0})
best_cats = {}
# Process each sentence
sents = zip(hyp_m2, ref_m2)
for sent_id, sent in enumerate(sents):
# Simplify the edits into lists of lists
# if "A1" in sent[0] or "A1" in sent[1] or sent_id in sent_id_cons:
# sent_id_cons.append(sent_id)
src = sent[0].split("\n")[0]
hyp_edits = simplify_edits(sent[0], args.max_answer_num)
ref_edits = simplify_edits(sent[1], args.max_answer_num)
# Process the edits for detection/correction based on args
hyp_dict = process_edits(hyp_edits, args)
ref_dict = process_edits(ref_edits, args)
if args.reference_num is None or len(ref_dict.keys()) == args.reference_num:
# Evaluate edits and get best TP, FP, FN hyp+ref combo.
count_dict, cat_dict = evaluate_edits(src,
hyp_dict, ref_dict, best_dict, sent_id, args)
# Merge these dicts with best_dict and best_cats
best_dict += Counter(count_dict)
best_cats = merge_dict(best_cats, cat_dict)
# Print results
print_results(best_dict, best_cats, args)
# Parse command line args
def parse_args():
parser = argparse.ArgumentParser(
description="Calculate F-scores for error detection and/or correction.\n"
"Flags let you evaluate at different levels of granularity.",
formatter_class=argparse.RawTextHelpFormatter,
usage="%(prog)s [options] -hyp HYP -ref REF")
parser.add_argument(
"-hyp",
help="A hypothesis M2 file.",
required=True)
parser.add_argument(
"-ref",
help="A reference M2 file.",
required=True)
parser.add_argument(
"--start",
type=int,
default=None
)
parser.add_argument(
"--end",
type=int,
default=None
)
parser.add_argument(
"--max_answer_num",
type=int,
default=None
)
parser.add_argument(
"--reference_num",
type=int,
default=None
)
parser.add_argument(
"-b",
"--beta",
help="Value of beta in F-score. (default: 0.5)",
default=0.5,
type=float)
parser.add_argument(
"-v",
"--verbose",
help="Print verbose output.",
action="store_true")
eval_type = parser.add_mutually_exclusive_group()
eval_type.add_argument(
"-dt",
help="Evaluate Detection in terms of Tokens.",
action="store_true")
eval_type.add_argument(
"-ds",
help="Evaluate Detection in terms of Spans.",
action="store_true")
eval_type.add_argument(
"-cs",
help="Evaluate Correction in terms of Spans. (default)",
action="store_true")
eval_type.add_argument(
"-cse",
help="Evaluate Correction in terms of Spans and Error types.",
action="store_true")
parser.add_argument(
"-single",
help="Only evaluate single token edits; i.e. 0:1, 1:0 or 1:1",
action="store_true")
parser.add_argument(
"-multi",
help="Only evaluate multi token edits; i.e. 2+:n or n:2+",
action="store_true")
parser.add_argument(
"-multi_hyp_avg",
help="When get multiple hypotheses for a sentence, calculate their average F-scores for this sentence.",
action="store_true") # For IAA calculation
parser.add_argument(
"-multi_hyp_max",
help="When get multiple hypotheses for a sentence, calculate their F-scores and select the max one for this sentence.",
action="store_true") # For multiple hypotheses system evaluation
parser.add_argument(
"-filt",
help="Do not evaluate the specified error types.",
nargs="+",
default=[])
parser.add_argument(
"-cat",
help="Show error category scores.\n"
"1: Only show operation tier scores; e.g. R.\n"
"2: Only show main tier scores; e.g. NOUN.\n"
"3: Show all category scores; e.g. R:NOUN.",
choices=[1, 2, 3],
type=int)
args = parser.parse_args()
return args
# Input: An m2 format sentence with edits.
# Output: A list of lists. Each edit: [start, end, cat, cor, coder]
def simplify_edits(sent, max_answer_num):
out_edits = []
# Get the edit lines from an m2 block.
edits = sent.split("\n")
# Loop through the edits
for edit in edits:
# Preprocessing
if edit.startswith("A "):
edit = edit[2:].split("|||") # Ignore "A " then split.
span = edit[0].split()
start = int(span[0])
end = int(span[1])
cat = edit[1]
cor = edit[2].replace(" ", "")
coder = int(edit[-1])
out_edit = [start, end, cat, cor, coder]
out_edits.append(out_edit)
# return [edit for edit in out_edits if edit[-1] in [0,1]]
if max_answer_num is None:
return out_edits
elif max_answer_num == 1:
return [edit for edit in out_edits if edit[-1] == 0]
elif max_answer_num == 2:
return [edit for edit in out_edits if edit[-1] in [0, 1]]
elif max_answer_num == 3:
return [edit for edit in out_edits if edit[-1] in [0, 1, 2]]
# Input 1: A list of edits. Each edit: [start, end, cat, cor, coder]
# Input 2: Command line args
# Output: A dict; key is coder, value is edit dict.
def process_edits(edits, args):
coder_dict = {}
# Add an explicit noop edit if there are no edits.
if not edits: edits = [[-1, -1, "noop", "-NONE-", 0]]
# Loop through the edits
for edit in edits:
# Name the edit elements for clarity
start = edit[0]
end = edit[1]
cat = edit[2]
cor = edit[3]
coder = edit[4]
# Add the coder to the coder_dict if necessary
if coder not in coder_dict: coder_dict[coder] = {}
# Optionally apply filters based on args
# 1. UNK type edits are only useful for detection, not correction.
if not args.dt and not args.ds and cat == "UNK": continue
# 2. Only evaluate single token edits; i.e. 0:1, 1:0 or 1:1
if args.single and (end-start >= 2 or len(cor.split()) >= 2): continue
# 3. Only evaluate multi token edits; i.e. 2+:n or n:2+
if args.multi and end-start < 2 and len(cor.split()) < 2: continue
# 4. If there is a filter, ignore the specified error types
if args.filt and cat in args.filt: continue
# Token Based Detection
if args.dt:
# Preserve noop edits.
if start == -1:
if (start, start) in coder_dict[coder].keys():
coder_dict[coder][(start, start)].append(cat)
else:
coder_dict[coder][(start, start)] = [cat]
# Insertions defined as affecting the token on the right
elif start == end and start >= 0:
if (start, start+1) in coder_dict[coder].keys():
coder_dict[coder][(start, start+1)].append(cat)
else:
coder_dict[coder][(start, start+1)] = [cat]
# Edit spans are split for each token in the range.
else:
for tok_id in range(start, end):
if (tok_id, tok_id+1) in coder_dict[coder].keys():
coder_dict[coder][(tok_id, tok_id+1)].append(cat)
else:
coder_dict[coder][(tok_id, tok_id+1)] = [cat]
# Span Based Detection
elif args.ds:
if (start, end) in coder_dict[coder].keys():
coder_dict[coder][(start, end)].append(cat)
else:
coder_dict[coder][(start, end)] = [cat]
# Span Based Correction
else:
# With error type classification
if args.cse:
if (start, end, cat, cor) in coder_dict[coder].keys():
coder_dict[coder][(start, end, cat, cor)].append(cat)
else:
coder_dict[coder][(start, end, cat, cor)] = [cat]
# Without error type classification
else:
if (start, end, cor) in coder_dict[coder].keys():
coder_dict[coder][(start, end, cor)].append(cat)
else:
coder_dict[coder][(start, end, cor)] = [cat]
return coder_dict
# Input 1: A hyp dict; key is coder_id, value is dict of processed hyp edits.
# Input 2: A ref dict; key is coder_id, value is dict of processed ref edits.
# Input 3: A dictionary of the best corpus level TP, FP and FN counts so far.
# Input 4: Sentence ID (for verbose output only)
# Input 5: Command line args
# Output 1: A dict of the best corpus level TP, FP and FN for the input sentence.
# Output 2: The corresponding error type dict for the above dict.
def evaluate_edits(src, hyp_dict, ref_dict, best, sent_id, args):
# Store the best sentence level scores and hyp+ref combination IDs
# best_f is initialised as -1 because 0 is a valid result.
best_tp, best_fp, best_fn, best_f, best_hyp, best_ref = 0, 0, 0, -1, 0, 0
best_cat = {}
# skip not annotatable sentence
if len(ref_dict.keys()) == 1:
ref_id = list(ref_dict.keys())[0]
if len(ref_dict[ref_id].keys()) == 1:
cat = list(ref_dict[ref_id].values())[0][0]
if cat == "NA":
best_dict = {"tp":best_tp, "fp":best_fp, "fn":best_fn}
return best_dict, best_cat
# Compare each hyp and ref combination
for hyp_id in hyp_dict.keys():
for ref_id in ref_dict.keys():
# Get the local counts for the current combination.
tp, fp, fn, cat_dict = compareEdits(hyp_dict[hyp_id], ref_dict[ref_id])
# Compute the local sentence scores (for verbose output only)
loc_p, loc_r, loc_f = computeFScore(tp, fp, fn, args.beta)
# Compute the global sentence scores
p, r, f = computeFScore(
tp+best["tp"], fp+best["fp"], fn+best["fn"], args.beta)
# Save the scores if they are better in terms of:
# 1. Higher F-score
# 2. Same F-score, higher TP
# 3. Same F-score and TP, lower FP
# 4. Same F-score, TP and FP, lower FN
if (f > best_f) or \
(f == best_f and tp > best_tp) or \
(f == best_f and tp == best_tp and fp < best_fp) or \
(f == best_f and tp == best_tp and fp == best_fp and fn < best_fn):
best_tp, best_fp, best_fn = tp, fp, fn
best_f, best_hyp, best_ref = f, hyp_id, ref_id
best_cat = cat_dict
# Verbose output
if args.verbose:
# Prepare verbose output edits.
hyp_verb = list(sorted(hyp_dict[hyp_id].keys()))
ref_verb = list(sorted(ref_dict[ref_id].keys()))
# Ignore noop edits
if not hyp_verb or hyp_verb[0][0] == -1: hyp_verb = []
if not ref_verb or ref_verb[0][0] == -1: ref_verb = []
# Print verbose info
print('{:-^40}'.format(""))
print("SENTENCE "+str(sent_id)+src[1:])
print('{:-^40}'.format(""))
print("SENTENCE "+str(sent_id)+" - HYP "+str(hyp_id)+" - REF "+str(ref_id))
print("HYPOTHESIS EDITS :", hyp_verb)
print("REFERENCE EDITS :", ref_verb)
print("Local TP/FP/FN :", str(tp), str(fp), str(fn))
print("Local P/R/F"+str(args.beta)+" :", str(loc_p), str(loc_r), str(loc_f))
print("Global TP/FP/FN :", str(tp+best["tp"]), str(fp+best["fp"]), str(fn+best["fn"]))
print("Global P/R/F"+str(args.beta)+" :", str(p), str(r), str(f))
# Verbose output: display the best hyp+ref combination
if args.verbose:
print('{:-^40}'.format(""))
print("^^ HYP "+str(best_hyp)+", REF "+str(best_ref)+" chosen for sentence "+str(sent_id))
# Save the best TP, FP and FNs as a dict, and return this and the best_cat dict
best_dict = {"tp":best_tp, "fp":best_fp, "fn":best_fn}
return best_dict, best_cat
# Input 1: A dictionary of hypothesis edits for a single system.
# Input 2: A dictionary of reference edits for a single annotator.
# Output 1-3: The TP, FP and FN for the hyp vs the given ref annotator.
# Output 4: A dictionary of the error type counts.
def compareEdits(hyp_edits, ref_edits):
tp = 0 # True Positives
fp = 0 # False Positives
fn = 0 # False Negatives
cat_dict = {} # {cat: [tp, fp, fn], ...}
for h_edit, h_cats in hyp_edits.items():
# noop hyp edits cannot be TP or FP
if h_cats[0] == "noop": continue
# TRUE POSITIVES
if h_edit in ref_edits.keys():
# On occasion, multiple tokens at same span.
for h_cat in ref_edits[h_edit]: # Use ref dict for TP
tp += 1
# Each dict value [TP, FP, FN]
if h_cat in cat_dict.keys():
cat_dict[h_cat][0] += 1
else:
cat_dict[h_cat] = [1, 0, 0]
# FALSE POSITIVES
else:
# On occasion, multiple tokens at same span.
for h_cat in h_cats:
fp += 1
# Each dict value [TP, FP, FN]
if h_cat in cat_dict.keys():
cat_dict[h_cat][1] += 1
else:
cat_dict[h_cat] = [0, 1, 0]
for r_edit, r_cats in ref_edits.items():
# noop ref edits cannot be FN
if r_cats[0] == "noop": continue
# FALSE NEGATIVES
if r_edit not in hyp_edits.keys():
# On occasion, multiple tokens at same span.
for r_cat in r_cats:
fn += 1
# Each dict value [TP, FP, FN]
if r_cat in cat_dict.keys():
cat_dict[r_cat][2] += 1
else:
cat_dict[r_cat] = [0, 0, 1]
return tp, fp, fn, cat_dict
# Input 1-3: True positives, false positives, false negatives
# Input 4: Value of beta in F-score.
# Output 1-3: Precision, Recall and F-score rounded to 4dp.
def computeFScore(tp, fp, fn, beta):
p = float(tp)/(tp+fp) if fp else 1.0
r = float(tp)/(tp+fn) if fn else 1.0
f = float((1+(beta**2))*p*r)/(((beta**2)*p)+r) if p+r else 0.0
return round(p, 4), round(r, 4), round(f, 4)
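# A small worked example (a sketch, not part of the original script): with tp=2, fp=1,
# fn=1 and beta=0.5, p = r = 2/3 and
# F0.5 = (1 + 0.5**2) * p * r / ((0.5**2) * p + r) ≈ 0.6667, so
#     computeFScore(2, 1, 1, 0.5)  ->  (0.6667, 0.6667, 0.6667)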
# Input 1-2: Two error category dicts. Key is cat, value is list of TP, FP, FN.
# Output: The dictionaries combined with cumulative TP, FP, FN.
def merge_dict(dict1, dict2):
for cat, stats in dict2.items():
if cat in dict1.keys():
dict1[cat] = [x+y for x, y in zip(dict1[cat], stats)]
else:
dict1[cat] = stats
return dict1
# Input 1: A dict; key is error cat, value is counts for [tp, fp, fn]
# Input 2: Integer value denoting level of error category granularity.
# 1: Operation tier; e.g. M, R, U. 2: Main tier; e.g. NOUN, VERB 3: Everything.
# Output: A dictionary of category TP, FP and FN based on Input 2.
def processCategories(cat_dict, setting):
# Otherwise, do some processing.
proc_cat_dict = {}
for cat, cnt in cat_dict.items():
if cat == "UNK":
proc_cat_dict[cat] = cnt
continue
# M, U, R or UNK combined only.
if setting == 1:
if cat[0] in proc_cat_dict.keys():
proc_cat_dict[cat[0]] = [x+y for x, y in zip(proc_cat_dict[cat[0]], cnt)]
else:
proc_cat_dict[cat[0]] = cnt
# Everything without M, U or R.
elif setting == 2:
if cat[2:] in proc_cat_dict.keys():
proc_cat_dict[cat[2:]] = [x+y for x, y in zip(proc_cat_dict[cat[2:]], cnt)]
else:
proc_cat_dict[cat[2:]] = cnt
# All error category combinations
else:
return cat_dict
return proc_cat_dict
# Input 1: A dict of global best TP, FP and FNs
# Input 2: A dict of error types and counts for those TP, FP and FNs
# Input 3: Command line args
def print_results(best, best_cats, args):
# Prepare output title.
if args.dt: title = " Token-Based Detection "
elif args.ds: title = " Span-Based Detection "
elif args.cse: title = " Span-Based Correction + Classification "
else: title = " Span-Based Correction "
# Category Scores
if args.cat:
best_cats = processCategories(best_cats, args.cat)
print("")
print('{:=^66}'.format(title))
print("Category".ljust(14), "TP".ljust(8), "FP".ljust(8), "FN".ljust(8),
"P".ljust(8), "R".ljust(8), "F"+str(args.beta))
for cat, cnts in sorted(best_cats.items()):
cat_p, cat_r, cat_f = computeFScore(cnts[0], cnts[1], cnts[2], args.beta)
print(cat.ljust(14), str(cnts[0]).ljust(8), str(cnts[1]).ljust(8),
str(cnts[2]).ljust(8), str(cat_p).ljust(8), str(cat_r).ljust(8), cat_f)
# Print the overall results.
print("")
print('{:=^46}'.format(title))
print("\t".join(["TP", "FP", "FN", "Prec", "Rec", "F"+str(args.beta)]))
print("\t".join(map(str, [best["tp"], best["fp"],
best["fn"]]+list(computeFScore(best["tp"], best["fp"], best["fn"], args.beta)))))
print('{:=^46}'.format(""))
print("")
if __name__ == "__main__":
# Run the program
main()
import re
from ..utils.rc_f1 import CJRCEvaluator
"""
given a target substring, find all of its occurrences in the string s
return the start and end index of every occurrence
"""
def __find_substring_starts(s, target):
return [(m.start(), m.end()) for m in re.finditer(target, s)]
"""
compute the reading comprehension F1 scores
hyps and refs are lists of hypothesis and reference strings
"""
def compute_rc_f1(hyps, refs):
scores = 0
for h, r in zip(hyps, refs):
scores += CJRCEvaluator.compute_f1(r, h)
return {'score': scores / len(hyps)}
"""
compute the information extraction F1 scores
hyps and refs are lists of hypothesis and reference strings
entity_types: a set of all possible entity types
"""
def compute_ie_f1(hyps, refs, entity_types):
assert (len(hyps) == len(refs))
scores, abstentions = 0, 0
for h, r in zip(hyps, refs):
h = __extract_entities_pred(h, entity_types)
r = __extract_entities_ref(r)
if r == {}:
scores += 1 if h == {} else 0
continue
if h == {}:
abstentions += 1
intersected = [CJRCEvaluator.compute_f1(r[etype], einstance) for etype, einstance in h.items() if etype in r]
prec = sum(intersected) / len(h) if len(h) > 0 else 0
rec = sum(intersected) / len(r) if len(r) > 0 else 0
# print(prec, rec, intersected)
scores += 2 * prec * rec / (prec + rec + 1e-10)
return {'score': scores / len(hyps), "abstention_rate": abstentions / len(hyps)}
def __extract_entities_ref(ref):
outputs = {}
if ref.strip() == '':
return outputs
for seg in ref.split(';'):
seg = seg.split(':')
outputs[seg[0]] = seg[1]
return outputs
"""
extract entity type and instances from the model prediction
pred: string of model prediction
entity_types: a set of all possible entity types
"""
def __extract_entities_pred(pred, entity_types):
outputs = {}
for etype in entity_types:
occurances = __find_substring_starts(pred, etype)
for start, end in occurances:
if end >= (len(pred) - 2):
continue
if pred[end] == ":" or pred[end] == ":":
einstance = re.split("\n| ", pred[end + 1:].strip())[0].strip()
if einstance != '无' and einstance != '未提及':
outputs[etype] = einstance
return outputs
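# A minimal illustration (the prediction text is hypothetical):
#
#     pred = "犯罪嫌疑人:张某\n被盗物品:手机"
#     __extract_entities_pred(pred, {"犯罪嫌疑人", "被盗物品", "受害人"})
#     # -> {"犯罪嫌疑人": "张某", "被盗物品": "手机"}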
from rouge_chinese import Rouge
import jieba
from nltk.translate.gleu_score import corpus_gleu
def compute_f1_two_sets(pred_set, gt_set):
precision = len(pred_set.intersection(gt_set)) / len(pred_set) if len(pred_set) > 0 else 0
recall = len(pred_set.intersection(gt_set)) / len(gt_set) if len(gt_set) > 0 else 0
f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0
return f1
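# e.g. (illustrative sets): compute_f1_two_sets({"a", "b"}, {"b", "c"}) -> 0.5
# (precision = recall = 0.5, so F1 = 0.5)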
def multi_choice_judge(prediction, option_list, answer_token):
# a dict, key: letters in the option list, value: count of the letter in the prediction
count_dict, abstention, accuracy = {}, 0, 0
for option in option_list:
option_count = prediction.count(option)
count_dict[option] = 1 if option_count > 0 else 0 # multiple occurrence of the same letter is counted as 1
if sum(count_dict.values()) == 0:
abstention = 1
# if the answer token is the only predicted token, the prediction is correct
elif count_dict[answer_token] == 1 and sum(count_dict.values()) == 1:
accuracy = 1
return {"score": accuracy, "abstention": abstention}
"""
compute the ROUGE score.
hyps and refs are lists of hypothesis and reference strings
empty predictions are replaced with 无内容
"""
def compute_rouge(hyps, refs):
assert(len(hyps) == len(refs))
hyps = [' '.join(jieba.cut(h)) for h in hyps]
hyps = [h if h.strip() != "" else "无内容" for h in hyps]
refs = [' '.join(jieba.cut(r)) for r in refs]
return Rouge().get_scores(hyps, refs)
"""
compute the GLEU score.
hyps and refs are lists of hypothesis and reference strings
empty predictions are replaced with 无内容
"""
def compute_gleu(hyps, refs):
assert(len(hyps) == len(refs))
hyps = [' '.join(jieba.cut(h)) for h in hyps]
hyps = [h if h.strip() != "" else "无内容" for h in hyps]
refs = [[' '.join(jieba.cut(r))] for r in refs]
return corpus_gleu(refs, hyps)
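# Minimal sketch (illustrative strings; both helpers segment with jieba first):
#     compute_rouge(["法院判决被告赔偿原告损失"], ["法院判决被告赔偿损失"])  # list of per-pair ROUGE dicts
#     compute_gleu(["法院判决被告赔偿原告损失"], ["法院判决被告赔偿损失"])   # corpus-level GLEU float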
import numpy as np
from typing import List, Tuple, Dict
from modules.tokenizer import Tokenizer
import os
from string import punctuation
REAL_PATH = os.path.split(os.path.realpath(__file__))[0]
chinese_punct = "!?。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘'‛“”„‟…‧﹏"
english_punct = punctuation
punct = chinese_punct + english_punct
def check_all_chinese(word):
"""
Check whether a word is composed entirely of Chinese characters
:param word:
:return:
"""
return all(['\u4e00' <= ch <= '\u9fff' for ch in word])
def read_cilin():
"""
Cilin 詞林 is a thesaurus with semantic information
"""
# TODO -- fix this path
project_dir = os.path.dirname(os.path.dirname(__file__)) # ymliu@2023.5.30 fix the path
lines = open(os.path.join(project_dir, "data", "cilin.txt"), "r", encoding="gbk").read().strip().split("\n")
semantic_dict = {}
semantic_classes = {}
for line in lines:
code, *words = line.split(" ")
for word in words:
semantic_dict[word] = code
# make reverse dict
if code in semantic_classes:
semantic_classes[code] += words
else:
semantic_classes[code] = words
return semantic_dict, semantic_classes
def read_confusion():
confusion_dict = {}
project_dir = os.path.dirname(os.path.dirname(__file__)) # ymliu@2023.5.30 fix the path
with open(os.path.join(project_dir, "data", "confusion_dict.txt"), "r", encoding="utf-8") as f:
for line in f:
li = line.rstrip('\n').split(" ")
confusion_dict[li[0]] = li[1:]
return confusion_dict
class Alignment:
"""
Align the erroneous sentence with the corrected sentence and
extract edit operations with an edit-distance algorithm.
"""
def __init__(
self,
semantic_dict: Dict,
confusion_dict: Dict,
granularity: str = "word",
) -> None:
"""
Constructor.
:param semantic_dict: semantic dictionary (Cilin thesaurus)
:param confusion_dict: character confusion set
"""
self.insertion_cost = 1
self.deletion_cost = 1
self.semantic_dict = semantic_dict
self.confusion_dict = confusion_dict
# Because we use character level tokenization, this doesn't currently use POS
self._open_pos = {} # at word level, POS agreement could additionally be used to compute the cost
self.granularity = granularity # word-level or character-level
self.align_seqs = []
def __call__(self,
src: List[Tuple],
tgt: List[Tuple],
verbose: bool = False):
cost_matrix, oper_matrix = self.align(src, tgt)
align_seq = self.get_cheapest_align_seq(oper_matrix)
if verbose:
print("========== Seg. and POS: ==========")
print(src)
print(tgt)
print("========== Cost Matrix ==========")
print(cost_matrix)
print("========== Oper Matrix ==========")
print(oper_matrix)
print("========== Alignment ==========")
print(align_seq)
print("========== Results ==========")
for a in align_seq:
print(a[0], src[a[1]: a[2]], tgt[a[3]: a[4]])
return align_seq
def _get_semantic_class(self, word):
"""
NOTE: Based on the paper:
Improved-Edit-Distance Kernel for Chinese Relation Extraction
Get the semantic class of each word (based on Cilin, with three levels).
"""
if word in self.semantic_dict:
code = self.semantic_dict[word]
high, mid, low = code[0], code[1], code[2:4]
return high, mid, low
else: # unknown
return None
@staticmethod
def _get_class_diff(a_class, b_class):
"""
d == 3 for equivalent semantics
d == 0 for completely different semantics
Compute the gap between the semantic classes of two words based on Cilin.
"""
d = sum([a == b for a, b in zip(a_class, b_class)])
return d
def _get_semantic_cost(self, a, b):
"""
Compute the substitution cost based on semantic information.
:param a: word a
:param b: word b
:return: substitution edit cost
"""
a_class = self._get_semantic_class(a)
b_class = self._get_semantic_class(b)
# unknown class, default to 1
if a_class is None or b_class is None:
return 4
elif a_class == b_class:
return 0
else:
return 2 * (3 - self._get_class_diff(a_class, b_class))
def _get_pos_cost(self, a_pos, b_pos):
"""
Compute the edit cost based on POS information.
:param a_pos: POS tag of word a
:param b_pos: POS tag of word b
:return: substitution edit cost
"""
if a_pos == b_pos:
return 0
elif a_pos in self._open_pos and b_pos in self._open_pos:
return 0.25
else:
return 0.499
def _get_char_cost(self, a, b, pinyin_a, pinyin_b):
"""
NOTE: This is a replacement of ERRANT's lemma cost for Chinese.
Compute the edit cost based on character similarity.
"""
if not (check_all_chinese(a) and check_all_chinese(b)):
return 0.5
if len(a) > len(b):
a, b = b, a
pinyin_a, pinyin_b = pinyin_b, pinyin_a
if a == b:
return 0
else:
return self._get_spell_cost(a, b, pinyin_a, pinyin_b)
def _get_spell_cost(self, a, b, pinyin_a, pinyin_b):
"""
Compute the spelling similarity of two words, combining shape similarity and pronunciation similarity.
:param a: word a
:param b: word b (word a is assumed to be no longer than word b)
:param pinyin_a: pinyin of word a
:param pinyin_b: pinyin of word b
:return: substitution cost
"""
count = 0
for i in range(len(a)):
for j in range(len(b)):
if a[i] == b[j] or (set(pinyin_a) & set(pinyin_b)) or (b[j] in self.confusion_dict.keys() and a[i] in self.confusion_dict[b[j]]) or (a[i] in self.confusion_dict.keys() and b[j] in self.confusion_dict[a[i]]):
count += 1
break
return (len(a) - count) / (len(a) * 2)
def get_sub_cost(self, a_seg, b_seg):
"""
Calculate the substitution cost between words a and b
The maximum cost is 2, equal to one deletion plus one insertion.
"""
if a_seg[0] == b_seg[0]:
return 0
if self.granularity == "word": # 词级别可以额外利用词性信息
semantic_cost = self._get_semantic_cost(a_seg[0], b_seg[0]) / 6.0
pos_cost = self._get_pos_cost(a_seg[1], b_seg[1])
char_cost = self._get_char_cost(a_seg[0], b_seg[0], a_seg[2], b_seg[2])
return semantic_cost + pos_cost + char_cost
else: # at character level, only character semantics (from Cilin) and surface similarity are available
semantic_cost = self._get_semantic_cost(a_seg[0], b_seg[0]) / 6.0
if a_seg[0] in punct and b_seg[0] in punct:
pos_cost = 0.0
elif a_seg[0] not in punct and b_seg[0] not in punct:
pos_cost = 0.25
else:
pos_cost = 0.499
# pos_cost = 0.0 if (a_seg[0] in punct and b_seg[0] in punct) or (a_seg[0] not in punct and b_seg[0] not in punct) else 0.5
char_cost = self._get_char_cost(a_seg[0], b_seg[0], a_seg[2], b_seg[2])
return semantic_cost + char_cost + pos_cost
def align(self,
src: List[Tuple],
tgt: List[Tuple]):
"""
Based on ERRANT's alignment
Using an improved dynamic-programming algorithm, assign an edit label to each character of the
source sentence so that it can be transformed into the target sentence.
Edit operation types:
1) M: Match (KEEP), the current character stays unchanged
2) D: Delete, the current character should be deleted
3) I: Insert, a character should be inserted here
4) T: Transposition, a word-order change
"""
cost_matrix = np.zeros((len(src) + 1, len(tgt) + 1)) # edit cost matrix
oper_matrix = np.full(
(len(src) + 1, len(tgt) + 1), "O", dtype=object
) # operation matrix
# Fill in the edges
for i in range(1, len(src) + 1):
cost_matrix[i][0] = cost_matrix[i - 1][0] + 1
oper_matrix[i][0] = ["D"]
for j in range(1, len(tgt) + 1):
cost_matrix[0][j] = cost_matrix[0][j - 1] + 1
oper_matrix[0][j] = ["I"]
# Loop through the cost matrix
for i in range(len(src)):
for j in range(len(tgt)):
# Matches
if src[i][0] == tgt[j][0]: # if the two tokens are identical, it is a Match (M) with zero edit cost
cost_matrix[i + 1][j + 1] = cost_matrix[i][j]
oper_matrix[i + 1][j + 1] = ["M"]
# Non-matches
else:
del_cost = cost_matrix[i][j + 1] + self.deletion_cost # total cost of a deletion
ins_cost = cost_matrix[i + 1][j] + self.insertion_cost # total cost of an insertion
sub_cost = cost_matrix[i][j] + self.get_sub_cost(
src[i], tgt[j]
) # total cost of a substitution
# Calculate transposition cost
trans_cost = float("inf")
k = 1
while (
i - k >= 0
and j - k >= 0
and cost_matrix[i - k + 1][j - k + 1]
!= cost_matrix[i - k][j - k]
):
p1 = sorted([a[0] for a in src][i - k: i + 1])
p2 = sorted([b[0] for b in tgt][j - k: j + 1])
if p1 == p2:
trans_cost = cost_matrix[i - k][j - k] + k
break
k += 1
costs = [trans_cost, sub_cost, ins_cost, del_cost]
ind = costs.index(min(costs))
cost_matrix[i + 1][j + 1] = costs[ind]
# ind = costs.index(costs[ind], ind+1)
for idx, cost in enumerate(costs):
if cost == costs[ind]:
if idx == 0:
if oper_matrix[i + 1][j + 1] == "O":
oper_matrix[i + 1][j + 1] = ["T" + str(k + 1)]
else:
oper_matrix[i + 1][j + 1].append("T" + str(k + 1))
elif idx == 1:
if oper_matrix[i + 1][j + 1] == "O":
oper_matrix[i + 1][j + 1] = ["S"]
else:
oper_matrix[i + 1][j + 1].append("S")
elif idx == 2:
if oper_matrix[i + 1][j + 1] == "O":
oper_matrix[i + 1][j + 1] = ["I"]
else:
oper_matrix[i + 1][j + 1].append("I")
else:
if oper_matrix[i + 1][j + 1] == "O":
oper_matrix[i + 1][j + 1] = ["D"]
else:
oper_matrix[i + 1][j + 1].append("D")
return cost_matrix, oper_matrix
def _dfs(self, i, j, align_seq_now, oper_matrix, strategy="all"):
"""
Depth-first traversal to collect all alignment sequences that attain the minimum edit distance.
"""
if i + j == 0:
self.align_seqs.append(align_seq_now)
else:
ops = oper_matrix[i][j] # analogous to enumerating all root-to-leaf paths in a tree
if strategy != "all": ops = ops[:1]
for op in ops:
if op in {"M", "S"}:
self._dfs(i - 1, j - 1, align_seq_now + [(op, i - 1, i, j - 1, j)], oper_matrix, strategy)
elif op == "D":
self._dfs(i - 1, j, align_seq_now + [(op, i - 1, i, j, j)], oper_matrix, strategy)
elif op == "I":
self._dfs(i, j - 1, align_seq_now + [(op, i, i, j - 1, j)], oper_matrix, strategy)
else:
k = int(op[1:])
self._dfs(i - k, j - k, align_seq_now + [(op, i - k, i, j - k, j)], oper_matrix, strategy)
def get_cheapest_align_seq(self, oper_matrix):
"""
Backtrack to obtain the edit sequences with the minimum edit distance.
"""
self.align_seqs = []
i = oper_matrix.shape[0] - 1
j = oper_matrix.shape[1] - 1
if abs(i - j) > 10:
self._dfs(i, j , [], oper_matrix, "first")
else:
self._dfs(i, j , [], oper_matrix, "all")
final_align_seqs = [seq[::-1] for seq in self.align_seqs]
return final_align_seqs
if __name__ == "__main__":
tokenizer = Tokenizer("word")
semantic_dict, semantic_class = read_cilin()
confusion_dict = read_confusion()
alignment = Alignment(semantic_dict, confusion_dict)
sents = ["首先 , 我们 得 准备 : 大 虾六 到 九 只 、 盐 一 茶匙 、 已 搾 好 的 柠檬汁 三 汤匙 、 泰国 柠檬 叶三叶 、 柠檬 香草 一 根 、 鱼酱 两 汤匙 、 辣椒 6 粒 , 纯净 水 4量杯 、 香菜 半量杯 和 草菇 10 个 。".replace(" ", ""), "首先 , 我们 得 准备 : 大 虾六 到 九 只 、 盐 一 茶匙 、 已 榨 好 的 柠檬汁 三 汤匙 、 泰国 柠檬 叶三叶 、 柠檬 香草 一 根 、 鱼酱 两 汤匙 、 辣椒 六 粒 , 纯净 水 四 量杯 、 香菜 半量杯 和 草菇 十 个 。".replace(" ", "")]
src, tgt = tokenizer(sents)
alignment(src, tgt, verbose=True)
from typing import List, Tuple
from modules.alignment import read_cilin, read_confusion, Alignment
from modules.merger import Merger
from modules.classifier import Classifier
class Annotator:
def __init__(self,
align: Alignment,
merger: Merger,
classifier: Classifier,
granularity: str = "word",
strategy: str = "first"):
self.align = align
self.merger = merger
self.classifier = classifier
self.granularity = granularity
self.strategy = strategy
@classmethod
def create_default(cls, granularity: str = "word", strategy: str = "first"):
"""
Default parameters used in the paper
"""
semantic_dict, semantic_class = read_cilin()
confusion_dict = read_confusion()
align = Alignment(semantic_dict, confusion_dict, granularity)
merger = Merger(granularity)
classifier = Classifier(granularity)
return cls(align, merger, classifier, granularity, strategy)
def __call__(self,
src: List[Tuple],
tgt: List[Tuple],
annotator_id: int = 0,
verbose: bool = False):
"""
Align sentences and annotate them with error type information
"""
src_tokens = [x[0] for x in src]
tgt_tokens = [x[0] for x in tgt]
src_str = "".join(src_tokens)
tgt_str = "".join(tgt_tokens)
# convert to text form
annotations_out = ["S " + " ".join(src_tokens) + "\n"]
if tgt_str == "没有错误" or src_str == tgt_str: # Error Free Case
annotations_out.append(f"T{annotator_id} 没有错误\n")
cors = [tgt_str]
op, toks, inds = "noop", "-NONE-", (-1, -1)
a_str = f"A {inds[0]} {inds[1]}|||{op}|||{toks}|||REQUIRED|||-NONE-|||{annotator_id}\n"
annotations_out.append(a_str)
elif tgt_str == "无法标注": # Not Annotatable Case
annotations_out.append(f"T{annotator_id} 无法标注\n")
cors = [tgt_str]
op, toks, inds = "NA", "-NONE-", (-1, -1)
a_str = f"A {inds[0]} {inds[1]}|||{op}|||{toks}|||REQUIRED|||-NONE-|||{annotator_id}\n"
annotations_out.append(a_str)
else: # Other
align_objs = self.align(src, tgt)
edit_objs = []
align_idx = 0
if self.strategy == "first":
align_objs = align_objs[:1]
for align_obj in align_objs:
edits = self.merger(align_obj, src, tgt, verbose)
if edits not in edit_objs:
edit_objs.append(edits)
annotations_out.append(f"T{annotator_id}-A{align_idx} " + " ".join(tgt_tokens) + "\n")
align_idx += 1
cors = self.classifier(src, tgt, edits, verbose)
# annotations_out = []
for cor in cors:
op, toks, inds = cor.op, cor.toks, cor.inds
a_str = f"A {inds[0]} {inds[1]}|||{op}|||{toks}|||REQUIRED|||-NONE-|||{annotator_id}\n"
annotations_out.append(a_str)
annotations_out.append("\n")
return annotations_out, cors
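# A hedged usage sketch, mirroring the __main__ block of modules/alignment.py
# (the sentence pair below is illustrative, not from the dataset):
#
#     from modules.tokenizer import Tokenizer
#     tokenizer = Tokenizer("word")
#     annotator = Annotator.create_default(granularity="word")
#     src, tgt = tokenizer(["这家伙还挺爱学习", "这家伙还很爱学习"])
#     m2_lines, corrections = annotator(src, tgt)
#     print("".join(m2_lines))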
from char_smi import CharFuncs
from collections import namedtuple
from pypinyin import pinyin, Style
import os
Correction = namedtuple(
"Correction",
[
"op",
"toks",
"inds",
],
)
file_path = os.path.dirname(os.path.abspath(__file__))
char_smi = CharFuncs(os.path.join(file_path.replace("modules", ""), 'data/char_meta.txt'))
def check_spell_error(src_span: str,
tgt_span: str,
threshold: float = 0.8) -> bool:
if len(src_span) != len(tgt_span):
return False
src_chars = [ch for ch in src_span]
tgt_chars = [ch for ch in tgt_span]
if sorted(src_chars) == sorted(tgt_chars): # characters transposed within the word
return True
for src_char, tgt_char in zip(src_chars, tgt_chars):
if src_char != tgt_char:
if src_char not in char_smi.data or tgt_char not in char_smi.data:
return False
v_sim = char_smi.shape_similarity(src_char, tgt_char)
p_sim = char_smi.pronunciation_similarity(src_char, tgt_char)
if v_sim + p_sim < threshold and not (
set(pinyin(src_char, style=Style.NORMAL, heteronym=True)[0]) & set(pinyin(tgt_char, style=Style.NORMAL, heteronym=True)[0])):
return False
return True
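# Illustrative behaviour (a sketch; the second case also depends on data/char_meta.txt):
#     check_spell_error("国中", "中国")   # True: same characters transposed within the word
#     check_spell_error("苹果", "香蕉")   # False: the differing characters are neither similar nor homophones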
class Classifier:
"""
Error type classifier
"""
def __init__(self,
granularity: str = "word"):
self.granularity = granularity
@staticmethod
def get_pos_type(pos):
if pos in {"n", "nd"}:
return "NOUN"
if pos in {"nh", "ni", "nl", "ns", "nt", "nz"}:
return "NOUN-NE"
if pos in {"v"}:
return "VERB"
if pos in {"a", "b"}:
return "ADJ"
if pos in {"c"}:
return "CONJ"
if pos in {"r"}:
return "PRON"
if pos in {"d"}:
return "ADV"
if pos in {"u"}:
return "AUX"
# if pos in {"k"}: # TODO 后缀词比例太少,暂且分入其它
# return "SUFFIX"
if pos in {"m"}:
return "NUM"
if pos in {"p"}:
return "PREP"
if pos in {"q"}:
return "QUAN"
if pos in {"wp"}:
return "PUNCT"
return "OTHER"
def __call__(self,
src,
tgt,
edits,
verbose: bool = False):
"""
Assign an error type to each edit operation.
:param src: source (erroneous) sentence information
:param tgt: target (correct) sentence information
:param edits: edit operations
:param verbose: whether to print information
:return: edit operations with error types assigned
"""
results = []
src_tokens = [x[0] for x in src]
tgt_tokens = [x[0] for x in tgt]
for edit in edits:
error_type = edit[0]
src_span = " ".join(src_tokens[edit[1]: edit[2]])
tgt_span = " ".join(tgt_tokens[edit[3]: edit[4]])
# print(tgt_span)
cor = None
if error_type[0] == "T":
cor = Correction("W", tgt_span, (edit[1], edit[2]))
elif error_type[0] == "D":
if self.granularity == "word": # 词级别可以细分错误类型
if edit[2] - edit[1] > 1: # 词组冗余暂时分为OTHER
cor = Correction("R:OTHER", "-NONE-", (edit[1], edit[2]))
else:
pos = self.get_pos_type(src[edit[1]][1])
pos = "NOUN" if pos == "NOUN-NE" else pos
pos = "MC" if tgt_span == "[缺失成分]" else pos
cor = Correction("R:{:s}".format(pos), "-NONE-", (edit[1], edit[2]))
else: # at character level, the type is determined directly by the operation
cor = Correction("R", "-NONE-", (edit[1], edit[2]))
elif error_type[0] == "I":
if self.granularity == "word": # 词级别可以细分错误类型
if edit[4] - edit[3] > 1: # 词组丢失暂时分为OTHER
cor = Correction("M:OTHER", tgt_span, (edit[1], edit[2]))
else:
pos = self.get_pos_type(tgt[edit[3]][1])
pos = "NOUN" if pos == "NOUN-NE" else pos
pos = "MC" if tgt_span == "[缺失成分]" else pos
cor = Correction("M:{:s}".format(pos), tgt_span, (edit[1], edit[2]))
else: # at character level, the type is determined directly by the operation
cor = Correction("M", tgt_span, (edit[1], edit[2]))
elif error_type[0] == "S":
if self.granularity == "word": # 词级别可以细分错误类型
if check_spell_error(src_span.replace(" ", ""), tgt_span.replace(" ", "")):
cor = Correction("S:SPELL", tgt_span, (edit[1], edit[2]))
# TODO: named-entity spelling errors are not distinguished separately for now
# if edit[4] - edit[3] > 1:
# cor = Correction("S:SPELL:COMMON", tgt_span, (edit[1], edit[2]))
# else:
# pos = self.get_pos_type(tgt[edit[3]][1])
# if pos == "NOUN-NE": # 命名实体拼写有误
# cor = Correction("S:SPELL:NE", tgt_span, (edit[1], edit[2]))
# else: # common-word misspelling
# cor = Correction("S:SPELL:COMMON", tgt_span, (edit[1], edit[2]))
else:
if edit[4] - edit[3] > 1: # substituted multi-word spans are grouped into OTHER for now
cor = Correction("S:OTHER", tgt_span, (edit[1], edit[2]))
else:
pos = self.get_pos_type(tgt[edit[3]][1])
pos = "NOUN" if pos == "NOUN-NE" else pos
pos = "MC" if tgt_span == "[缺失成分]" else pos
cor = Correction("S:{:s}".format(pos), tgt_span, (edit[1], edit[2]))
else: # at character level, the type is determined directly by the operation
cor = Correction("S", tgt_span, (edit[1], edit[2]))
results.append(cor)
if verbose:
print("========== Corrections ==========")
for cor in results:
print("Type: {:s}, Position: {:d} -> {:d}, Target: {:s}".format(cor.op, cor.inds[0], cor.inds[1], cor.toks))
return results
# print(pinyin("朝", style=Style.NORMAL))
from itertools import groupby
from string import punctuation
from typing import List
from modules.tokenizer import Tokenizer
from modules.alignment import Alignment, read_cilin, read_confusion
import Levenshtein
class Merger:
"""
Merge edit operations, converting them from token level to span level
"""
def __init__(self,
granularity: str = "word",
merge: bool = False):
chinese_punct = "!?。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟–—‘'‛“”„‟…‧."
self.punctuation = punctuation + chinese_punct
self.not_merge_token = [punct for punct in self.punctuation]
self.granularity = granularity
self.merge = merge
@staticmethod
def _merge_edits(seq, tag="X"):
if seq:
return [(tag, seq[0][1], seq[-1][2], seq[0][3], seq[-1][4])]
else:
return seq
@staticmethod
def _check_revolve(span_a, span_b):
span_a = span_a + span_a
return span_b in span_a
def _process_seq(self, seq, src_tokens, tgt_tokens):
if len(seq) <= 1:
return seq
ops = [op[0] for op in seq]
if set(ops) == {"D"} or set(ops) == {"I"}:
return self._merge_edits(seq, set(ops).pop())
if set(ops) == {"D", "I"} or set(ops) == {"I", "D"}:
# do not merge this pattern_from_qua.txt
return seq
if set(ops) == {"S"}:
if self.granularity == "word":
return seq
else:
return self._merge_edits(seq, "S")
if set(ops) == {"M"}:
return self._merge_edits(seq, "M")
return self._merge_edits(seq, "S")
def __call__(self,
align_obj,
src: List,
tgt: List,
verbose: bool = False):
"""
Based on ERRANT's merge, adapted for Chinese
"""
src_tokens = [x[0] for x in src]
tgt_tokens = [x[0] for x in tgt]
edits = []
# Split alignment into groups of M, T and rest. (T has a number after it)
# TODO: if an inserted/deleted/substituted span contains punctuation, do not merge it with other edits
# TODO: the missing-component label ([缺失成分]) is likewise not merged with other edits
for op, group in groupby(
align_obj,
lambda x: x[0][0] if x[0][0] in {"M", "T"} else False,
):
group = list(group)
# T is always split TODO: Evaluate this
if op == "T":
for seq in group:
edits.append(seq)
# Process D, I and S subsequence
else:
# Turn the processed sequence into edits
processed = self._process_seq(group, src_tokens, tgt_tokens)
for seq in processed:
edits.append(seq)
filtered_edits = []
i = 0
while i < len(edits):
e1 = edits[i][0][0]
if i < len(edits) - 2:
e2 = edits[i + 1][0][0]
e3 = edits[i + 2][0][0]
# Find "S M S" patterns
# Ex:
# S M S
# 冬阴功 对 外国人
# 外国人 对 冬阴功
if e1 == "S" and e2 == "M" and e3 == "S":
w1 = "".join(src_tokens[edits[i][1]: edits[i][2]])
w2 = "".join(tgt_tokens[edits[i][3]: edits[i][4]])
w3 = "".join(src_tokens[edits[i + 2][1]: edits[i + 2][2]])
w4 = "".join(tgt_tokens[edits[i + 2][3]: edits[i + 2][4]])
if min([len(w1), len(w2), len(w3), len(w4)]) == 1:
if w1 == w4 and w2 == w3:
group = [edits[i], edits[i + 1], edits[i + 2]]
processed = self._merge_edits(group, "T" + str(edits[i+2][2] - edits[i][1]))
for seq in processed:
filtered_edits.append(seq)
i += 3
else:
filtered_edits.append(edits[i])
i += 1
else:
if Levenshtein.distance(w1, w4) <= 1 and Levenshtein.distance(w2, w3) <= 1:
group = [edits[i], edits[i + 1], edits[i + 2]]
processed = self._merge_edits(group, "T" + str(edits[i + 2][2] - edits[i][1]))
for seq in processed:
filtered_edits.append(seq)
i += 3
else:
filtered_edits.append(edits[i])
i += 1
# Find "D M I" or "I M D" patterns
# Ex:
# D M I
# 旅游 去 陌生 的 地方
# 去 陌生 的 地方 旅游
elif (e1 == "D" and (e2 == "M" or e2.startswith("T")) and e3 == "I") or (e1 == "I" and (e2 == "M" or e2.startswith("T")) and e3 == "D"):
if e1 == "D":
delete_token = src_tokens[edits[i][1]: edits[i][2]]
insert_token = tgt_tokens[edits[i + 2][3]: edits[i + 2][4]]
else:
delete_token = src_tokens[edits[i + 2][1]: edits[i + 2][2]]
insert_token = tgt_tokens[edits[i][3]: edits[i][4]]
a, b = "".join(delete_token), "".join(insert_token)
if len(a) < len(b):
a, b = b, a
if a not in self.punctuation and b not in self.punctuation and len(a) - len(b) <= 1:
if len(b) == 1:
if a == b:
group = [edits[i], edits[i + 1], edits[i + 2]]
processed = self._merge_edits(group, "T" + str(edits[i+2][2] - edits[i][1]))
for seq in processed:
filtered_edits.append(seq)
i += 3
else:
filtered_edits.append(edits[i])
i += 1
else:
if Levenshtein.distance(a, b) <= 1 or (len(a) == len(b) and self._check_revolve(a, b)):
group = [edits[i], edits[i + 1], edits[i + 2]]
processed = self._merge_edits(group, "T" + str(edits[i + 2][2] - edits[i][1]))
for seq in processed:
filtered_edits.append(seq)
i += 3
else:
filtered_edits.append(edits[i])
i += 1
else:
filtered_edits.append(edits[i])
i += 1
else:
if e1 != "M":
filtered_edits.append(edits[i])
i += 1
else:
if e1 != "M":
filtered_edits.append(edits[i])
i += 1
# In rare cases with word-level tokenization, the following error can occur:
# M D S M
# 有 時 住 上層
# 有 時住 上層
# Which results in S: 時住 --> 時住
# We need to filter this case out
second_filter = []
for edit in filtered_edits: # avoid mismatches caused by word-segmentation errors
span1 = "".join(src_tokens[edit[1] : edit[2]])
span2 = "".join(tgt_tokens[edit[3] : edit[4]])
if span1 != span2:
if edit[0] == "S":
b = True
# In rare cases with word-level tokenization, the following error can occur:
# S I I M
# 负责任 老师
# 负 责任 的 老师
# Which results in S: 负责任 --> 负 责任 的
# We need to convert this edit to I: --> 的
# overlapping prefix
common_str = ""
tmp_new_start_1 = edit[1]
for i in range(edit[1], edit[2]):
if not span2.startswith(common_str + src_tokens[i]):
break
common_str += src_tokens[i]
tmp_new_start_1 = i + 1
new_start_1, new_start_2 = edit[1], edit[3]
if common_str:
tmp_str = ""
for i in range(edit[3], edit[4]):
tmp_str += tgt_tokens[i]
if tmp_str == common_str:
new_start_1, new_start_2 = tmp_new_start_1, i + 1
# second_filter.append(("S", new_start_1, edit[2], i + 1, edit[4]))
b = False
break
elif len(tmp_str) > len(common_str):
break
# overlapping suffix
common_str = ""
new_end_1, new_end_2 = edit[2], edit[4]
tmp_new_end_1 = edit[2]
for i in reversed(range(new_start_1, edit[2])):
if not span2.endswith(src_tokens[i] + common_str):
break
common_str = src_tokens[i] + common_str
tmp_new_end_1 = i
if common_str:
tmp_str = ""
for i in reversed(range(new_start_2, edit[4])):
tmp_str = tgt_tokens[i] + tmp_str
if tmp_str == common_str:
new_end_1, new_end_2 = tmp_new_end_1, i
b = False
break
elif len(tmp_str) > len(common_str):
break
if b:
second_filter.append(edit)
else:
if new_start_1 == new_end_1:
new_edit = ("I", new_start_1, new_end_1, new_start_2, new_end_2)
elif new_start_2 == new_end_2:
new_edit = ("D", new_start_1, new_end_1, new_start_2, new_end_2)
else:
new_edit = ("S", new_start_1, new_end_1, new_start_2, new_end_2)
second_filter.append(new_edit)
else:
second_filter.append(edit)
if verbose:
print("========== Parallels ==========")
print("".join(src_tokens))
print("".join(tgt_tokens))
print("========== Results ==========")
for edit in second_filter:
op = edit[0]
s = " ".join(src_tokens[edit[1]: edit[2]])
t = " ".join(tgt_tokens[edit[3]: edit[4]])
print(f"{op}:\t{s}\t-->\t{t}")
print("========== Infos ==========")
print(str(src))
print(str(tgt))
return second_filter
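# Quick demo: align two sentences at char level and print the merged span-level edits.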
if __name__ == "__main__":
tokenizer = Tokenizer("char")
semantic_dict, semantic_class = read_cilin()
confusion_dict = read_confusion()
alignment = Alignment(semantic_dict, confusion_dict)
sents = [
"所 以 印 度 对 全 世 界 人 没 有 说 服 不 要 吃 牛 肉 。".replace(
" ", ""),
"所 以 印 度 没 有 说 服 全 世 界 人 不 要 吃 牛 肉 。".replace(
" ", "")]
src, tgt = tokenizer(sents)
align_obj = alignment(src, tgt)
m = Merger()
m(align_obj, src, tgt, verbose=True)
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import unicodedata
import six
def convert_to_unicode(text):
"""Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
if six.PY3:
if isinstance(text, str):
return text
elif isinstance(text, bytes):
return text.decode("utf-8", "ignore")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
elif six.PY2:
if isinstance(text, str):
return text.decode("utf-8", "ignore")
elif isinstance(text, unicode):
return text
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
else:
raise ValueError("Not running on Python2 or Python 3?")
def printable_text(text):
"""Returns text encoded in a way suitable for print or `tf.logging`."""
# These functions want `str` for both Python2 and Python3, but in one case
# it's a Unicode string and in the other it's a byte string.
if six.PY3:
if isinstance(text, str):
return text
elif isinstance(text, bytes):
return text.decode("utf-8", "ignore")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
elif six.PY2:
if isinstance(text, str):
return text
elif isinstance(text, unicode):
return text.encode("utf-8")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
else:
raise ValueError("Not running on Python2 or Python 3?")
def load_vocab(vocab_file):
"""Loads a vocabulary file into a dictionary."""
vocab = collections.OrderedDict()
index = 0
with open(vocab_file, "r") as reader:
while True:
token = convert_to_unicode(reader.readline())
if not token:
break
token = token.strip()
vocab[token] = index
index += 1
return vocab
def convert_by_vocab(vocab, items):
"""Converts a sequence of [tokens|ids] using the vocab."""
output = []
for item in items:
if item not in vocab:
print("warning: %s not in vocab" % item)
item = "[UNK]"
output.append(vocab[item])
return output
def convert_tokens_to_ids(vocab, tokens):
return convert_by_vocab(vocab, tokens)
def convert_ids_to_tokens(inv_vocab, ids):
return convert_by_vocab(inv_vocab, ids)
def whitespace_tokenize(text):
"""Runs basic whitespace cleaning and splitting on a peice of text."""
text = text.strip()
if not text:
return []
tokens = text.split()
return tokens
class FullTokenizer(object):
"""Runs end-to-end tokenziation."""
def __init__(self, vocab_file, do_lower_case=True):
self.vocab = load_vocab(vocab_file)
self.inv_vocab = {v: k for k, v in self.vocab.items()}
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
def tokenize(self, text):
split_tokens = []
for token in self.basic_tokenizer.tokenize(text):
for sub_token in self.wordpiece_tokenizer.tokenize(token):
split_tokens.append(sub_token)
return split_tokens
def convert_tokens_to_ids(self, tokens):
return convert_by_vocab(self.vocab, tokens)
def convert_ids_to_tokens(self, ids):
return convert_by_vocab(self.inv_vocab, ids)
class BasicTokenizer(object):
"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
def __init__(self, do_lower_case=True):
"""Constructs a BasicTokenizer.
Args:
do_lower_case: Whether to lower case the input.
"""
self.do_lower_case = do_lower_case
def tokenize(self, text):
"""Tokenizes a piece of text."""
text = convert_to_unicode(text)
text = self._clean_text(text)
# This was added on November 1st, 2018 for the multilingual and Chinese
# models. This is also applied to the English models now, but it doesn't
# matter since the English models were not trained on any Chinese data
# and generally don't have any Chinese data in them (there are Chinese
# characters in the vocabulary because Wikipedia does have some Chinese
# words in the English Wikipedia.).
text = self._tokenize_chinese_chars(text)
orig_tokens = whitespace_tokenize(text)
split_tokens = []
for token in orig_tokens:
if self.do_lower_case:
token = token.lower()
token = self._run_strip_accents(token)
split_tokens.extend(self._run_split_on_punc(token))
output_tokens = whitespace_tokenize(" ".join(split_tokens))
return output_tokens
def _run_strip_accents(self, text):
"""Strips accents from a piece of text."""
text = unicodedata.normalize("NFD", text)
output = []
for char in text:
cat = unicodedata.category(char)
if cat == "Mn":
continue
output.append(char)
return "".join(output)
def _run_split_on_punc(self, text):
"""Splits punctuation on a piece of text."""
chars = list(text)
i = 0
start_new_word = True
output = []
while i < len(chars):
char = chars[i]
if _is_punctuation(char):
output.append([char])
start_new_word = True
else:
if start_new_word:
output.append([])
start_new_word = False
output[-1].append(char)
i += 1
return ["".join(x) for x in output]
def _tokenize_chinese_chars(self, text):
"""Adds whitespace around any CJK character."""
output = []
for char in text:
cp = ord(char)
if self._is_chinese_char(cp):
output.append(" ")
output.append(char)
output.append(" ")
else:
output.append(char)
return "".join(output)
def _is_chinese_char(self, cp):
"""Checks whether CP is the codepoint of a CJK character."""
# This defines a "chinese character" as anything in the CJK Unicode block:
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
#
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
# despite its name. The modern Korean Hangul alphabet is a different block,
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
# space-separated words, so they are not treated specially and handled
# like all of the other languages.
if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
(cp >= 0x3400 and cp <= 0x4DBF) or #
(cp >= 0x20000 and cp <= 0x2A6DF) or #
(cp >= 0x2A700 and cp <= 0x2B73F) or #
(cp >= 0x2B740 and cp <= 0x2B81F) or #
(cp >= 0x2B820 and cp <= 0x2CEAF) or
(cp >= 0xF900 and cp <= 0xFAFF) or #
(cp >= 0x2F800 and cp <= 0x2FA1F)): #
return True
return False
def _clean_text(self, text):
"""Performs invalid character removal and whitespace cleanup on text."""
output = []
for char in text:
cp = ord(char)
if cp == 0 or cp == 0xfffd or _is_control(char):
continue
if _is_whitespace(char):
output.append(" ")
else:
output.append(char)
return "".join(output)
class WordpieceTokenizer(object):
"""Runs WordPiece tokenziation."""
def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100):
self.vocab = vocab
self.unk_token = unk_token
self.max_input_chars_per_word = max_input_chars_per_word
def tokenize(self, text):
"""Tokenizes a piece of text into its word pieces.
This uses a greedy longest-match-first algorithm to perform tokenization
using the given vocabulary.
For example:
input = "unaffable"
output = ["un", "##aff", "##able"]
Args:
text: A single token or whitespace separated tokens. This should have
already been passed through `BasicTokenizer`.
Returns:
A list of wordpiece tokens.
"""
text = convert_to_unicode(text)
output_tokens = []
for token in whitespace_tokenize(text):
chars = list(token)
if len(chars) > self.max_input_chars_per_word:
output_tokens.append(self.unk_token)
continue
is_bad = False
start = 0
sub_tokens = []
while start < len(chars):
end = len(chars)
cur_substr = None
while start < end:
substr = "".join(chars[start:end])
if start > 0:
substr = "##" + substr
if substr in self.vocab:
cur_substr = substr
break
end -= 1
if cur_substr is None:
is_bad = True
break
sub_tokens.append(cur_substr)
start = end
if is_bad:
# output_tokens.append(self.unk_token)
output_tokens.append(token) # keep the UNK token
else:
output_tokens.extend(sub_tokens)
return output_tokens
def _is_whitespace(char):
"""Checks whether `chars` is a whitespace character."""
# \t, \n, and \r are technically control characters but we treat them
# as whitespace since they are generally considered as such.
if char == " " or char == "\t" or char == "\n" or char == "\r":
return True
cat = unicodedata.category(char)
if cat == "Zs":
return True
return False
def _is_control(char):
"""Checks whether `chars` is a control character."""
# These are technically control characters but we count them as whitespace
# characters.
if char == "\t" or char == "\n" or char == "\r":
return False
cat = unicodedata.category(char)
if cat.startswith("C"):
return True
return False
def _is_punctuation(char):
"""Checks whether `chars` is a punctuation character."""
cp = ord(char)
# We treat all non-letter/number ASCII as punctuation.
# Characters such as "^", "$", and "`" are not in the Unicode
# Punctuation class but we treat them as punctuation anyways, for
# consistency.
if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
(cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
return True
cat = unicodedata.category(char)
if cat.startswith("P"):
return True
return False
from ltp import LTP
from typing import List
from pypinyin import pinyin, Style, lazy_pinyin
import torch
import os
import functools
class Tokenizer:
"""
Tokenizer (char-level or word-level).
"""
def __init__(self,
granularity: str = "word",
device: str = "cpu",
segmented: bool = False,
bpe: bool = False,
) -> None:
"""
Constructor.
:param granularity: tokenization granularity, char-level ("char") or word-level ("word")
"""
self.ltp = None
if granularity == "word":
self.ltp = LTP(device=torch.device(device) if torch.cuda.is_available() else torch.device("cpu"))
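# register the missing-component placeholder "[缺失成分]" as one word so the segmenter never splits it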
self.ltp.add_words(words=["[缺失成分]"], max_window=6)
self.segmented = segmented
self.granularity = granularity
if self.granularity == "word":
self.tokenizer = self.split_word
elif self.granularity == "char":
self.tokenizer = functools.partial(self.split_char, bpe=bpe)
else:
raise NotImplementedError
def __repr__(self) -> str:
return "{:s}\nMode:{:s}\n}".format(str(self.__class__.__name__), self.mode)
def __call__(self,
input_strings: List[str]
) -> List:
"""
Tokenize a batch of strings.
:param input_strings: list of strings to tokenize
:return: list of tokenized results; each token is a (token, pos_tag, pinyin) tuple
"""
if not self.segmented:
input_strings = ["".join(s.split(" ")) for s in input_strings]
results = self.tokenizer(input_strings)
return results
def split_char(self, input_strings: List[str], bpe=False) -> List:
"""
Character-level tokenization.
:param input_strings: strings to split into characters
:return: character-level tokenization results
"""
if bpe:
from . import tokenization
project_dir = os.path.dirname(os.path.dirname(__file__))
tokenizer = tokenization.FullTokenizer(vocab_file=os.path.join(project_dir,"data","chinese_vocab.txt"), do_lower_case=False)
results = []
for input_string in input_strings:
if not self.segmented: # if the input is not pre-segmented, split it into single characters (or BPE tokens when bpe=True); otherwise keep the original segmentation
segment_string = " ".join([char for char in input_string] if not bpe else tokenizer.tokenize(input_string))
else:
segment_string = input_string
# print(segment_string)
segment_string = segment_string.replace("[ 缺 失 成 分 ]", "[缺失成分]").split(" ") # treat the missing-component placeholder as a single token
results.append([(char, "unk", pinyin(char, style=Style.NORMAL, heteronym=True)[0]) for char in segment_string])
return results
def split_word(self, input_strings: List[str]) -> List:
"""
Word-level tokenization.
:param input_strings: strings to tokenize
:return: word-level tokenization results
"""
if self.segmented:
seg, hidden = self.ltp.seg([input_string.split(" ") for input_string in input_strings], is_preseged=True)
else:
seg, hidden = self.ltp.seg(input_strings)
pos = self.ltp.pos(hidden)
result = []
for s, p in zip(seg, pos):
pinyin = [lazy_pinyin(word) for word in s]
result.append(list(zip(s, p, pinyin)))
return result
if __name__ == "__main__":
tokenizer = Tokenizer("word")
print(tokenizer(["LAC是个优秀的分词工具", "百度是一家高科技公司"]))
import os
from modules.annotator import Annotator
from modules.tokenizer import Tokenizer
import argparse
from collections import Counter
from tqdm import tqdm
import torch
from collections import defaultdict
from multiprocessing import Pool
from opencc import OpenCC
import timeout_decorator
os.environ["TOKENIZERS_PARALLELISM"] = "false"
annotator, sentence_to_tokenized = None, None
cc = OpenCC("t2s")
@timeout_decorator.timeout(10)
def annotate_with_time_out(line):
"""
:param line:
:return:
"""
sent_list = line.split("\t")[1:]
source = sent_list[0]
if args.segmented:
source = source.strip()
else:
source = "".join(source.strip().split())
output_str = ""
for idx, target in enumerate(sent_list[1:]):
try:
if args.segmented:
target = target.strip()
else:
target = "".join(target.strip().split())
if not args.no_simplified:
target = cc.convert(target)
source_tokenized, target_tokenized = sentence_to_tokenized[source], sentence_to_tokenized[target]
out, cors = annotator(source_tokenized, target_tokenized, idx)
if idx == 0:
output_str += "".join(out[:-1])
else:
output_str += "".join(out[1:-1])
except Exception:
raise
return output_str
def annotate(line):
"""
:param line:
:return:
"""
sent_list = line.split("\t")[1:]
source = sent_list[0]
if args.segmented:
source = source.strip()
else:
source = "".join(source.strip().split())
output_str = ""
for idx, target in enumerate(sent_list[1:]):
try:
if args.segmented:
target = target.strip()
else:
target = "".join(target.strip().split())
if not args.no_simplified:
target = cc.convert(target)
source_tokenized, target_tokenized = sentence_to_tokenized[source], sentence_to_tokenized[target]
out, cors = annotator(source_tokenized, target_tokenized, idx)
if idx == 0:
output_str += "".join(out[:-1])
else:
output_str += "".join(out[1:-1])
except Exception:
raise
return output_str
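# First pass: tokenize every distinct sentence, then run the time-limited annotator over each line
# to record which lines exceed the timeout.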
def firsttime_process(args):
tokenizer = Tokenizer(args.granularity, args.device, args.segmented, args.bpe)
global annotator, sentence_to_tokenized
annotator = Annotator.create_default(args.granularity, args.multi_cheapest_strategy)
lines = open(args.file, "r", encoding="utf-8").read().strip().split("\n") # format: id src tgt1 tgt2...
# error_types = []
with open(args.output, "w", encoding="utf-8") as f:
count = 0
sentence_set = set()
sentence_to_tokenized = {}
for line in lines:
sent_list = line.split("\t")[1:]
for idx, sent in enumerate(sent_list):
if args.segmented:
# print(sent)
sent = sent.strip()
else:
sent = "".join(sent.split()).strip()
if idx >= 1:
if not args.no_simplified:
sentence_set.add(cc.convert(sent))
else:
sentence_set.add(sent)
else:
sentence_set.add(sent)
batch = []
for sent in tqdm(sentence_set):
count += 1
if sent:
batch.append(sent)
if count % args.batch_size == 0:
results = tokenizer(batch)
for s, r in zip(batch, results):
sentence_to_tokenized[s] = r # Get tokenization map.
batch = []
if batch:
results = tokenizer(batch)
for s, r in zip(batch, results):
sentence_to_tokenized[s] = r # Get tokenization map.
timeout_indices = []
# single-process mode
for idx, line in enumerate(tqdm(lines)):
try:
ret = annotate_with_time_out(line)
except Exception:
timeout_indices.append(idx)
return timeout_indices
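# Second pass: re-annotate everything without the timeout; for lines that timed out earlier,
# the target is replaced with a placeholder ("无") so the output file stays aligned with the input.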
def main(args):
timeout_indices = firsttime_process(args)
tokenizer = Tokenizer(args.granularity, args.device, args.segmented, args.bpe)
global annotator, sentence_to_tokenized
annotator = Annotator.create_default(args.granularity, args.multi_cheapest_strategy)
lines = open(args.file, "r", encoding="utf-8").read().strip().split("\n")
new_lines = []# format: id src tgt1 tgt2...
with open(args.output, "w", encoding="utf-8") as f:
count = 0
sentence_set = set()
sentence_to_tokenized = {}
for line_idx, line in enumerate(lines):
if line_idx in timeout_indices:
# print(f"line before split: {line}")
line_split = line.split("\t")
line_number, sent_list = line_split[0], line_split[1:]
assert len(sent_list) == 2
sent_list[-1] = " 无"
line = line_number + "\t" + "\t".join(sent_list)
# print(f"line time out: {line}")
new_lines.append(line)
else:
new_lines.append(line)
sent_list = line.split("\t")[1:]
for idx, sent in enumerate(sent_list):
if args.segmented:
# print(sent)
sent = sent.strip()
else:
sent = "".join(sent.split()).strip()
if idx >= 1:
if not args.no_simplified:
sentence_set.add(cc.convert(sent))
else:
sentence_set.add(sent)
else:
sentence_set.add(sent)
batch = []
for sent in tqdm(sentence_set):
count += 1
if sent:
batch.append(sent)
if count % args.batch_size == 0:
results = tokenizer(batch)
for s, r in zip(batch, results):
sentence_to_tokenized[s] = r # Get tokenization map.
batch = []
if batch:
results = tokenizer(batch)
for s, r in zip(batch, results):
sentence_to_tokenized[s] = r # Get tokenization map.
# single-process mode
lines = new_lines
for idx, line in enumerate(tqdm(lines)):
ret = annotate(line)
f.write(ret)
f.write("\n")
# multi-process mode: only tested on Linux; recommended for use on a Linux server
# with Pool(args.worker_num) as pool:
# for ret in pool.imap(annotate, tqdm(lines), chunksize=8):
# if ret:
# f.write(ret)
# f.write("\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Choose input file to annotate")
parser.add_argument("-f", "--file", type=str, required=True, help="Input parallel file")
parser.add_argument("-o", "--output", type=str, help="Output file", required=True)
parser.add_argument("-b", "--batch_size", type=int, help="The size of batch", default=128)
parser.add_argument("-d", "--device", type=int, help="The ID of GPU", default=0)
parser.add_argument("-w", "--worker_num", type=int, help="The number of workers", default=16)
parser.add_argument("-g", "--granularity", type=str, help="Choose char-level or word-level evaluation", default="char")
parser.add_argument("-m", "--merge", help="Whether merge continuous replacement/deletion/insertion", action="store_true")
parser.add_argument("-s", "--multi_cheapest_strategy", type=str, choices=["first", "all"], default="all")
parser.add_argument("--segmented", help="Whether tokens have been segmented", action="store_true") # 支持提前token化,用空格隔开
parser.add_argument("--no_simplified", help="Whether simplifying chinese", action="store_true") # 将所有corrections转换为简体中文
parser.add_argument("--bpe", help="Whether to use bpe", action="store_true") # 支持 bpe 切分英文单词
args = parser.parse_args()
main(args)
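# Example invocation (illustrative file names): python parallel_to_m2.py -f data.para -o data.para.m2 -g char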
"""Official evaluation script for CAIL-2021.
The code is partially based on the CoQA evaluation script.
"""
import json
import sys
from collections import Counter
class CJRCEvaluator:
def __init__(self, gold_file):
self.gold_data = CJRCEvaluator.gold_answers_to_dict(gold_file)
@staticmethod
def gold_answers_to_dict(gold_file):
dataset = json.load(open(gold_file, mode="r", encoding="utf-8"))
gold_dict = {}
# id_to_domain = {}
for story in dataset['data']:
qas = story["paragraphs"][0]["qas"]
for qa in qas:
qid = qa['id']
gold_answers = []
answers = qa["answers"]
if len(answers) == 0:
gold_answers = ['']
else:
for answer in qa["answers"]:
if type(answer) == dict:
gold_answers.append(answer["text"])
elif type(answer) == list:
gold_answers.append("".join([a["text"] for a in answer]))
if qid in gold_dict:
sys.stderr.write("Gold file has duplicate stories: {}".format(qid))
gold_dict[qid] = gold_answers
return gold_dict
@staticmethod
def preds_to_dict(pred_file):
preds = json.load(open(pred_file, mode="r", encoding="utf-8"))
pred_dict = {}
for pred in preds:
pred_dict[pred['id']] = "".join(pred['answer'])
return pred_dict
@staticmethod
def normalize_answer(s):
"""Lower text and remove punctuation, storys and extra whitespace."""
def remove_punc(text):
return "".join(ch for ch in text if ch.isdigit() or ch.isalpha())
def lower(text):
return text.lower()
return remove_punc(lower(s))
@staticmethod
def get_tokens(s):
if not s: return []
return list(CJRCEvaluator.normalize_answer(s))
@staticmethod
def compute_exact(a_gold, a_pred):
return int(CJRCEvaluator.normalize_answer(a_gold) == CJRCEvaluator.normalize_answer(a_pred))
@staticmethod
def compute_f1(a_gold, a_pred):
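# Character-level F1: both answers are normalized, split into characters, and compared as multisets.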
gold_toks = CJRCEvaluator.get_tokens(a_gold)
pred_toks = CJRCEvaluator.get_tokens(a_pred)
common = Counter(gold_toks) & Counter(pred_toks)
num_same = sum(common.values())
if len(gold_toks) == 0 or len(pred_toks) == 0:
# If either is no-answer, then F1 is 1 if they agree, 0 otherwise
return int(gold_toks == pred_toks)
if num_same == 0:
return 0
precision = 1.0 * num_same / len(pred_toks)
recall = 1.0 * num_same / len(gold_toks)
f1 = (2 * precision * recall) / (precision + recall)
return f1
@staticmethod
def _compute_turn_score(a_gold_list, a_pred):
f1_sum = 0.0
em_sum = 0.0
if len(a_gold_list) > 1:
for i in range(len(a_gold_list)):
# exclude the current answer
gold_answers = a_gold_list[0:i] + a_gold_list[i + 1:]
em_sum += max(CJRCEvaluator.compute_exact(a, a_pred) for a in gold_answers)
f1_sum += max(CJRCEvaluator.compute_f1(a, a_pred) for a in gold_answers)
else:
em_sum += max(CJRCEvaluator.compute_exact(a, a_pred) for a in a_gold_list)
f1_sum += max(CJRCEvaluator.compute_f1(a, a_pred) for a in a_gold_list)
return {'em': em_sum / max(1, len(a_gold_list)), 'f1': f1_sum / max(1, len(a_gold_list))}
def compute_turn_score(self, qid, a_pred):
''' This is the function you are probably looking for. a_pred is the answer string your model predicted. '''
a_gold_list = self.gold_data[qid]
return CJRCEvaluator._compute_turn_score(a_gold_list, a_pred)
def get_raw_scores(self, pred_data):
'''Returns a dict with scores.'''
exact_scores = {}
f1_scores = {}
for qid in self.gold_data:
if qid not in pred_data:
sys.stderr.write('Missing prediction for {}\n'.format(qid))
continue
a_pred = pred_data[qid]
scores = self.compute_turn_score(qid, a_pred)
# Take max over all gold answers
exact_scores[qid] = scores['em']
f1_scores[qid] = scores['f1']
return exact_scores, f1_scores
def get_raw_scores_human(self):
'''
Returns a dict with scores.
'''
exact_scores = {}
f1_scores = {}
for qid in self.gold_data:
f1_sum = 0.0
em_sum = 0.0
if len(self.gold_data[qid]) > 1:
for i in range(len(self.gold_data[qid])):
# exclude the current answer
gold_answers = self.gold_data[qid][0:i] + self.gold_data[qid][i + 1:]
em_sum += max(CJRCEvaluator.compute_exact(a, self.gold_data[qid][i]) for a in gold_answers)
f1_sum += max(CJRCEvaluator.compute_f1(a, self.gold_data[qid][i]) for a in gold_answers)
else:
exit("Gold answers should be multiple: {}={}".format(qid, self.gold_data[qid]))
exact_scores[qid] = em_sum / len(self.gold_data[qid])
f1_scores[qid] = f1_sum / len(self.gold_data[qid])
return exact_scores, f1_scores
def human_performance(self):
exact_scores, f1_scores = self.get_raw_scores_human()
return self.get_total_scores(exact_scores, f1_scores)
def model_performance(self, pred_data):
exact_scores, f1_scores = self.get_raw_scores(pred_data)
return self.get_total_scores(exact_scores, f1_scores)
def get_total_scores(self, exact_scores, f1_scores):
em_total, f1_total, turn_count = 0, 0, 0
scores = {}
for qid in self.gold_data:
em_total += exact_scores.get(qid, 0)
f1_total += f1_scores.get(qid, 0)
turn_count += 1
scores["F1"] = round(f1_total / max(1, turn_count) * 100, 1)
return scores
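# Minimal usage sketch (hypothetical file names):
#   evaluator = CJRCEvaluator("gold.json")
#   preds = CJRCEvaluator.preds_to_dict("pred.json")
#   print(evaluator.model_performance(preds))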
absl-py
accelerate>=0.19.0
boto3
+cn2an
colossalai
cpm_kernels
datasets>=2.12.0
@@ -9,11 +10,15 @@
fairscale
faiss_gpu==1.7.2
fuzzywuzzy
jieba
+ltp
mmengine>=0.8.2
nltk==3.8
numpy==1.23.4
openai
+OpenCC
pandas<2.0.0
+pypinyin
+python-Levenshtein
rank_bm25==0.2.2
rapidfuzz
requests==2.31.0
@@ -25,6 +30,7 @@
seaborn
sentence_transformers==2.2.2
tabulate
tiktoken
+timeout_decorator
tokenizers>=0.13.3
torch>=1.13.1
tqdm==4.64.1