Unverified Commit 4dd9a3fc authored by Leymore, committed by GitHub

[Sync] sync with internal codes 20231019 (#488)

parent 2737249f
from rouge_chinese import Rouge
import jieba
from nltk.translate.gleu_score import corpus_gleu
def compute_f1_two_sets(pred_set, gt_set):
    precision = len(pred_set.intersection(gt_set)) / len(pred_set) if len(pred_set) > 0 else 0
    recall = len(pred_set.intersection(gt_set)) / len(gt_set) if len(gt_set) > 0 else 0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0
    return f1
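# A minimal usage sketch (not part of the original file): the F1 balances
# precision and recall over the overlap of the two sets.
#   compute_f1_two_sets({"a", "b", "c"}, {"b", "c", "d"})  # -> 0.666...
#   (precision = recall = 2/3, so F1 = 2/3)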
def multi_choice_judge(prediction, option_list, answer_token):
    # a dict; key: letters in the option list, value: 1 if the letter appears in the prediction, else 0
    count_dict, abstention, accuracy = {}, 0, 0
    for option in option_list:
        option_count = prediction.count(option)
        count_dict[option] = 1 if option_count > 0 else 0  # multiple occurrences of the same letter count as 1
    if sum(count_dict.values()) == 0:
        abstention = 1
    # if the answer token is the only predicted token, the prediction is correct
    elif count_dict[answer_token] == 1 and sum(count_dict.values()) == 1:
        accuracy = 1
    return {"score": accuracy, "abstention": abstention}
""" """
compute the rouge score. compute the rouge score.
hyps and refs are lists of hyposisis and reference strings hyps and refs are lists of hyposisis and reference strings
empty predictions are replaces with 无内容 empty predictions are replaces with 无内容
""" """
def compute_rouge(hyps, refs):
    assert(len(hyps) == len(refs))
    hyps = [' '.join(jieba.cut(h)) for h in hyps]
    hyps = [h if h.strip() != "" else "无内容" for h in hyps]
    refs = [' '.join(jieba.cut(r)) for r in refs]
    return Rouge().get_scores(hyps, refs)
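# A minimal usage sketch (hedged: the return shape follows rouge_chinese's
# Rouge.get_scores, one dict per hypothesis/reference pair):
#   scores = compute_rouge(["今天天气很好"], ["今天天气不错"])
#   rouge_l_f = scores[0]["rouge-l"]["f"]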
""" """
compute the gleu score. compute the gleu score.
hyps and refs are lists of hyposisis and reference strings hyps and refs are lists of hyposisis and reference strings
empty predictions are replaces with 无内容 empty predictions are replaces with 无内容
""" """
def compute_gleu(hyps, refs):
    assert(len(hyps) == len(refs))
    hyps = [' '.join(jieba.cut(h)) for h in hyps]
    hyps = [h if h.strip() != "" else "无内容" for h in hyps]
    refs = [[' '.join(jieba.cut(r))] for r in refs]
    return corpus_gleu(refs, hyps)
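# A minimal usage sketch (note: as written above, hypotheses and references
# are passed to nltk's corpus_gleu as space-joined strings, not token lists):
#   score = compute_gleu(["今天天气很好"], ["今天天气不错"])  # a float in [0, 1]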
from typing import List, Tuple
from modules.alignment import read_cilin, read_confusion, Alignment
from modules.merger import Merger
from modules.classifier import Classifier
class Annotator:
    def __init__(self,
                 align: Alignment,
                 merger: Merger,
                 classifier: Classifier,
                 granularity: str = "word",
                 strategy: str = "first"):
        self.align = align
        self.merger = merger
        self.classifier = classifier
        self.granularity = granularity
        self.strategy = strategy

    @classmethod
    def create_default(cls, granularity: str = "word", strategy: str = "first"):
        """
        Default parameters used in the paper
        """
        semantic_dict, semantic_class = read_cilin()
        confusion_dict = read_confusion()
        align = Alignment(semantic_dict, confusion_dict, granularity)
        merger = Merger(granularity)
        classifier = Classifier(granularity)
        return cls(align, merger, classifier, granularity, strategy)
    def __call__(self,
                 src: List[Tuple],
                 tgt: List[Tuple],
                 annotator_id: int = 0,
                 verbose: bool = False):
        """
        Align sentences and annotate them with error type information
        """
        src_tokens = [x[0] for x in src]
        tgt_tokens = [x[0] for x in tgt]
        src_str = "".join(src_tokens)
        tgt_str = "".join(tgt_tokens)
        # convert to text form: "S <source tokens>" followed by one "A ..." line per edit
        annotations_out = ["S " + " ".join(src_tokens) + "\n"]
        if tgt_str == "没有错误" or src_str == tgt_str:  # Error Free Case
            annotations_out.append(f"T{annotator_id} 没有错误\n")
            cors = [tgt_str]
            op, toks, inds = "noop", "-NONE-", (-1, -1)
            a_str = f"A {inds[0]} {inds[1]}|||{op}|||{toks}|||REQUIRED|||-NONE-|||{annotator_id}\n"
            annotations_out.append(a_str)
        elif tgt_str == "无法标注":  # Not Annotatable Case
            annotations_out.append(f"T{annotator_id} 无法标注\n")
            cors = [tgt_str]
            op, toks, inds = "NA", "-NONE-", (-1, -1)
            a_str = f"A {inds[0]} {inds[1]}|||{op}|||{toks}|||REQUIRED|||-NONE-|||{annotator_id}\n"
            annotations_out.append(a_str)
        else:  # Other
            align_objs = self.align(src, tgt)
            edit_objs = []
            align_idx = 0
            if self.strategy == "first":
                align_objs = align_objs[:1]
            for align_obj in align_objs:
                edits = self.merger(align_obj, src, tgt, verbose)
                if edits not in edit_objs:
                    edit_objs.append(edits)
                    annotations_out.append(f"T{annotator_id}-A{align_idx} " + " ".join(tgt_tokens) + "\n")
                    align_idx += 1
                    cors = self.classifier(src, tgt, edits, verbose)
                    # annotations_out = []
                    for cor in cors:
                        op, toks, inds = cor.op, cor.toks, cor.inds
                        a_str = f"A {inds[0]} {inds[1]}|||{op}|||{toks}|||REQUIRED|||-NONE-|||{annotator_id}\n"
                        annotations_out.append(a_str)
        annotations_out.append("\n")
        return annotations_out, cors
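# A minimal usage sketch (assumption: `src`/`tgt` are lists of (token, pos_tag)
# tuples from the same segmentation pipeline this repo uses elsewhere; the
# output follows the "S ... / T ... / A ..." layout built above):
#   annotator = Annotator.create_default(granularity="word")
#   annotations, corrections = annotator(src, tgt)
#   print("".join(annotations))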
from char_smi import CharFuncs
from collections import namedtuple
from pypinyin import pinyin, Style
import os

Correction = namedtuple(
    "Correction",
    [
        "op",
        "toks",
        "inds",
    ],
)
file_path = os.path.dirname(os.path.abspath(__file__))
# resolve data/char_meta.txt relative to the repo root (this file lives under modules/)
char_smi = CharFuncs(os.path.join(file_path.replace("modules", ""), 'data/char_meta.txt'))
def check_spell_error(src_span: str,
                      tgt_span: str,
                      threshold: float = 0.8) -> bool:
    if len(src_span) != len(tgt_span):
        return False
    src_chars = [ch for ch in src_span]
    tgt_chars = [ch for ch in tgt_span]
    if sorted(src_chars) == sorted(tgt_chars):  # transposed characters within the word
        return True
    for src_char, tgt_char in zip(src_chars, tgt_chars):
        if src_char != tgt_char:
            if src_char not in char_smi.data or tgt_char not in char_smi.data:
                return False
            v_sim = char_smi.shape_similarity(src_char, tgt_char)
            p_sim = char_smi.pronunciation_similarity(src_char, tgt_char)
            if v_sim + p_sim < threshold and not (
                    set(pinyin(src_char, style=Style.NORMAL, heteronym=True)[0]) & set(pinyin(tgt_char, style=Style.NORMAL, heteronym=True)[0])):
                return False
    return True
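# A minimal usage sketch (assumption: both characters are covered by the
# data/char_meta.txt file loaded above):
#   check_spell_error("急燥", "急躁")  # -> True: 燥/躁 share the pinyin zao
#   check_spell_error("天气", "天空")  # -> likely False: 气/空 differ in sound and shape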
class Classifier:
    """
    Error type classifier
    """
    def __init__(self,
                 granularity: str = "word"):
        self.granularity = granularity

    @staticmethod
    def get_pos_type(pos):
        if pos in {"n", "nd"}:
            return "NOUN"
        if pos in {"nh", "ni", "nl", "ns", "nt", "nz"}:
            return "NOUN-NE"
        if pos in {"v"}:
            return "VERB"
        if pos in {"a", "b"}:
            return "ADJ"
        if pos in {"c"}:
            return "CONJ"
        if pos in {"r"}:
            return "PRON"
        if pos in {"d"}:
            return "ADV"
        if pos in {"u"}:
            return "AUX"
        # if pos in {"k"}:  # TODO: suffix words are too rare, so they fall into OTHER for now
        #     return "SUFFIX"
        if pos in {"m"}:
            return "NUM"
        if pos in {"p"}:
            return "PREP"
        if pos in {"q"}:
            return "QUAN"
        if pos in {"wp"}:
            return "PUNCT"
        return "OTHER"
    def __call__(self,
                 src,
                 tgt,
                 edits,
                 verbose: bool = False):
        """
        Assign an error type to each edit operation.
        :param src: source (erroneous) sentence information
        :param tgt: target (corrected) sentence information
        :param edits: edit operations
        :param verbose: whether to print information
        :return: edit operations with error types assigned
        """
        results = []
        src_tokens = [x[0] for x in src]
        tgt_tokens = [x[0] for x in tgt]
        for edit in edits:
            error_type = edit[0]
            src_span = " ".join(src_tokens[edit[1]: edit[2]])
            tgt_span = " ".join(tgt_tokens[edit[3]: edit[4]])
            # print(tgt_span)
            cor = None
            if error_type[0] == "T":
                cor = Correction("W", tgt_span, (edit[1], edit[2]))
            elif error_type[0] == "D":
                if self.granularity == "word":  # word level supports fine-grained error types
                    if edit[2] - edit[1] > 1:  # multi-word redundancy is grouped under OTHER for now
                        cor = Correction("R:OTHER", "-NONE-", (edit[1], edit[2]))
                    else:
                        pos = self.get_pos_type(src[edit[1]][1])
                        pos = "NOUN" if pos == "NOUN-NE" else pos
                        pos = "MC" if tgt_span == "[缺失成分]" else pos
                        cor = Correction("R:{:s}".format(pos), "-NONE-", (edit[1], edit[2]))
                else:  # at the character level, the operation alone determines the type
                    cor = Correction("R", "-NONE-", (edit[1], edit[2]))
            elif error_type[0] == "I":
                if self.granularity == "word":  # word level supports fine-grained error types
                    if edit[4] - edit[3] > 1:  # multi-word omission is grouped under OTHER for now
                        cor = Correction("M:OTHER", tgt_span, (edit[1], edit[2]))
                    else:
                        pos = self.get_pos_type(tgt[edit[3]][1])
                        pos = "NOUN" if pos == "NOUN-NE" else pos
                        pos = "MC" if tgt_span == "[缺失成分]" else pos
                        cor = Correction("M:{:s}".format(pos), tgt_span, (edit[1], edit[2]))
                else:  # at the character level, the operation alone determines the type
                    cor = Correction("M", tgt_span, (edit[1], edit[2]))
            elif error_type[0] == "S":
                if self.granularity == "word":  # word level supports fine-grained error types
                    if check_spell_error(src_span.replace(" ", ""), tgt_span.replace(" ", "")):
                        cor = Correction("S:SPELL", tgt_span, (edit[1], edit[2]))
                    # TODO: named-entity spelling errors are not distinguished separately for now
                    # if edit[4] - edit[3] > 1:
                    #     cor = Correction("S:SPELL:COMMON", tgt_span, (edit[1], edit[2]))
                    # else:
                    #     pos = self.get_pos_type(tgt[edit[3]][1])
                    #     if pos == "NOUN-NE":  # named-entity spelling error
                    #         cor = Correction("S:SPELL:NE", tgt_span, (edit[1], edit[2]))
                    #     else:  # common word spelling error
                    #         cor = Correction("S:SPELL:COMMON", tgt_span, (edit[1], edit[2]))
                    else:
                        if edit[4] - edit[3] > 1:  # multi-word substitution is grouped under OTHER for now
                            cor = Correction("S:OTHER", tgt_span, (edit[1], edit[2]))
                        else:
                            pos = self.get_pos_type(tgt[edit[3]][1])
                            pos = "NOUN" if pos == "NOUN-NE" else pos
                            pos = "MC" if tgt_span == "[缺失成分]" else pos
                            cor = Correction("S:{:s}".format(pos), tgt_span, (edit[1], edit[2]))
                else:  # at the character level, the operation alone determines the type
                    cor = Correction("S", tgt_span, (edit[1], edit[2]))
            results.append(cor)
        if verbose:
            print("========== Corrections ==========")
            for cor in results:
                print("Type: {:s}, Position: {:d} -> {:d}, Target: {:s}".format(cor.op, cor.inds[0], cor.inds[1], cor.toks))
        return results
# print(pinyin("朝", style=Style.NORMAL))
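# A minimal usage sketch (hypothetical edit tuple of the form
# (error_type, src_start, src_end, tgt_start, tgt_end), as consumed by
# __call__ above; requires data/char_meta.txt to be present):
#   clf = Classifier(granularity="word")
#   src = [("他", "r"), ("明天", "nt"), ("去", "v")]
#   tgt = [("他", "r"), ("明天", "nt"), ("走", "v")]
#   clf(src, tgt, [("S", 2, 3, 2, 3)])
#   # -> likely [Correction(op="S:VERB", toks="走", inds=(2, 3))],
#   #    assuming 去/走 are not judged a spelling confusion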