Unverified Commit 4dd9a3fc authored by Leymore, committed by GitHub

[Sync] sync with internal codes 20231019 (#488)

parent 2737249f
import argparse
from collections import Counter


def main():
    # Parse command line args
    args = parse_args()
    # Open hypothesis and reference m2 files and split into chunks
    hyp_m2 = open(args.hyp).read().strip().split("\n\n")[args.start:args.end] if args.start is not None and args.end is not None else open(args.hyp).read().strip().split("\n\n")
    ref_m2 = open(args.ref).read().strip().split("\n\n")[args.start:args.end] if args.start is not None and args.end is not None else open(args.ref).read().strip().split("\n\n")
    # Make sure they have the same number of sentences
    assert len(hyp_m2) == len(ref_m2), f"{len(hyp_m2)} != {len(ref_m2)}"
    # Store global corpus level best counts here
    best_dict = Counter({"tp": 0, "fp": 0, "fn": 0})
    best_cats = {}
    # Process each sentence
    sents = zip(hyp_m2, ref_m2)
    for sent_id, sent in enumerate(sents):
        # Simplify the edits into lists of lists
        # if "A1" in sent[0] or "A1" in sent[1] or sent_id in sent_id_cons:
        #     sent_id_cons.append(sent_id)
        src = sent[0].split("\n")[0]
        hyp_edits = simplify_edits(sent[0], args.max_answer_num)
        ref_edits = simplify_edits(sent[1], args.max_answer_num)
        # Process the edits for detection/correction based on args
        hyp_dict = process_edits(hyp_edits, args)
        ref_dict = process_edits(ref_edits, args)
        if args.reference_num is None or len(ref_dict.keys()) == args.reference_num:
            # Evaluate edits and get best TP, FP, FN hyp+ref combo.
            count_dict, cat_dict = evaluate_edits(
                src, hyp_dict, ref_dict, best_dict, sent_id, args)
            # Merge these dicts with best_dict and best_cats
            best_dict += Counter(count_dict)
            best_cats = merge_dict(best_cats, cat_dict)
    # Print results
    print_results(best_dict, best_cats, args)

# Parse command line args
def parse_args():
    parser = argparse.ArgumentParser(
        description="Calculate F-scores for error detection and/or correction.\n"
                    "Flags let you evaluate at different levels of granularity.",
        formatter_class=argparse.RawTextHelpFormatter,
        usage="%(prog)s [options] -hyp HYP -ref REF")
    parser.add_argument(
        "-hyp",
        help="A hypothesis M2 file.",
        required=True)
    parser.add_argument(
        "-ref",
        help="A reference M2 file.",
        required=True)
    parser.add_argument(
        "--start",
        type=int,
        default=None)
    parser.add_argument(
        "--end",
        type=int,
        default=None)
    parser.add_argument(
        "--max_answer_num",
        type=int,
        default=None)
    parser.add_argument(
        "--reference_num",
        type=int,
        default=None)
    parser.add_argument(
        "-b",
        "--beta",
        help="Value of beta in F-score. (default: 0.5)",
        default=0.5,
        type=float)
    parser.add_argument(
        "-v",
        "--verbose",
        help="Print verbose output.",
        action="store_true")
    eval_type = parser.add_mutually_exclusive_group()
    eval_type.add_argument(
        "-dt",
        help="Evaluate Detection in terms of Tokens.",
        action="store_true")
    eval_type.add_argument(
        "-ds",
        help="Evaluate Detection in terms of Spans.",
        action="store_true")
    eval_type.add_argument(
        "-cs",
        help="Evaluate Correction in terms of Spans. (default)",
        action="store_true")
    eval_type.add_argument(
        "-cse",
        help="Evaluate Correction in terms of Spans and Error types.",
        action="store_true")
    parser.add_argument(
        "-single",
        help="Only evaluate single token edits; i.e. 0:1, 1:0 or 1:1",
        action="store_true")
    parser.add_argument(
        "-multi",
        help="Only evaluate multi token edits; i.e. 2+:n or n:2+",
        action="store_true")
    parser.add_argument(
        "-multi_hyp_avg",
        help="When there are multiple hypotheses for a sentence, use their average F-score for this sentence.",
        action="store_true")  # For IAA calculation
    parser.add_argument(
        "-multi_hyp_max",
        help="When there are multiple hypotheses for a sentence, compute their F-scores and keep the maximum for this sentence.",
        action="store_true")  # For multiple hypotheses system evaluation
    parser.add_argument(
        "-filt",
        help="Do not evaluate the specified error types.",
        nargs="+",
        default=[])
    parser.add_argument(
        "-cat",
        help="Show error category scores.\n"
             "1: Only show operation tier scores; e.g. R.\n"
             "2: Only show main tier scores; e.g. NOUN.\n"
             "3: Show all category scores; e.g. R:NOUN.",
        choices=[1, 2, 3],
        type=int)
    args = parser.parse_args()
    return args

# Input: An m2 format sentence with edits.
# Output: A list of lists. Each edit: [start, end, cat, cor, coder]
def simplify_edits(sent, max_answer_num):
    out_edits = []
    # Get the edit lines from an m2 block.
    edits = sent.split("\n")
    # Loop through the edits
    for edit in edits:
        # Preprocessing
        if edit.startswith("A "):
            edit = edit[2:].split("|||")  # Ignore "A " then split.
            span = edit[0].split()
            start = int(span[0])
            end = int(span[1])
            cat = edit[1]
            cor = edit[2].replace(" ", "")
            coder = int(edit[-1])
            out_edit = [start, end, cat, cor, coder]
            out_edits.append(out_edit)
    # return [edit for edit in out_edits if edit[-1] in [0,1]]
    if max_answer_num is None:
        return out_edits
    elif max_answer_num == 1:
        return [edit for edit in out_edits if edit[-1] == 0]
    elif max_answer_num == 2:
        return [edit for edit in out_edits if edit[-1] in [0, 1]]
    elif max_answer_num == 3:
        return [edit for edit in out_edits if edit[-1] in [0, 1, 2]]

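# A minimal doctest-style sketch of the expected behaviour (hypothetical M2 block):
# >>> simplify_edits("S 我 爱 北京\nA 1 2|||R:VERB|||喜欢|||REQUIRED|||-NONE-|||0", None)
# [[1, 2, 'R:VERB', '喜欢', 0]]
# With max_answer_num == 1, only edits produced by coder 0 would be kept.
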
# Input 1: A list of edits. Each edit: [start, end, cat, cor, coder]
# Input 2: Command line args
# Output: A dict; key is coder, value is edit dict.
def process_edits(edits, args):
    coder_dict = {}
    # Add an explicit noop edit if there are no edits.
    if not edits: edits = [[-1, -1, "noop", "-NONE-", 0]]
    # Loop through the edits
    for edit in edits:
        # Name the edit elements for clarity
        start = edit[0]
        end = edit[1]
        cat = edit[2]
        cor = edit[3]
        coder = edit[4]
        # Add the coder to the coder_dict if necessary
        if coder not in coder_dict: coder_dict[coder] = {}
        # Optionally apply filters based on args
        # 1. UNK type edits are only useful for detection, not correction.
        if not args.dt and not args.ds and cat == "UNK": continue
        # 2. Only evaluate single token edits; i.e. 0:1, 1:0 or 1:1
        if args.single and (end-start >= 2 or len(cor.split()) >= 2): continue
        # 3. Only evaluate multi token edits; i.e. 2+:n or n:2+
        if args.multi and end-start < 2 and len(cor.split()) < 2: continue
        # 4. If there is a filter, ignore the specified error types
        if args.filt and cat in args.filt: continue
        # Token Based Detection
        if args.dt:
            # Preserve noop edits.
            if start == -1:
                if (start, start) in coder_dict[coder].keys():
                    coder_dict[coder][(start, start)].append(cat)
                else:
                    coder_dict[coder][(start, start)] = [cat]
            # Insertions defined as affecting the token on the right
            elif start == end and start >= 0:
                if (start, start+1) in coder_dict[coder].keys():
                    coder_dict[coder][(start, start+1)].append(cat)
                else:
                    coder_dict[coder][(start, start+1)] = [cat]
            # Edit spans are split for each token in the range.
            else:
                for tok_id in range(start, end):
                    if (tok_id, tok_id+1) in coder_dict[coder].keys():
                        coder_dict[coder][(tok_id, tok_id+1)].append(cat)
                    else:
                        coder_dict[coder][(tok_id, tok_id+1)] = [cat]
        # Span Based Detection
        elif args.ds:
            if (start, end) in coder_dict[coder].keys():
                coder_dict[coder][(start, end)].append(cat)
            else:
                coder_dict[coder][(start, end)] = [cat]
        # Span Based Correction
        else:
            # With error type classification
            if args.cse:
                if (start, end, cat, cor) in coder_dict[coder].keys():
                    coder_dict[coder][(start, end, cat, cor)].append(cat)
                else:
                    coder_dict[coder][(start, end, cat, cor)] = [cat]
            # Without error type classification
            else:
                if (start, end, cor) in coder_dict[coder].keys():
                    coder_dict[coder][(start, end, cor)].append(cat)
                else:
                    coder_dict[coder][(start, end, cor)] = [cat]
    return coder_dict

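# Sketch of the resulting structure in the default span-based correction mode
# (assuming args.dt, args.ds and args.cse are all False and no filters are set):
# process_edits([[1, 2, 'R:VERB', '喜欢', 0]], args) -> {0: {(1, 2, '喜欢'): ['R:VERB']}}
# while an empty edit list becomes {0: {(-1, -1, '-NONE-'): ['noop']}}.
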
# Input 1: A hyp dict; key is coder_id, value is dict of processed hyp edits.
# Input 2: A ref dict; key is coder_id, value is dict of processed ref edits.
# Input 3: A dictionary of the best corpus level TP, FP and FN counts so far.
# Input 4: Sentence ID (for verbose output only)
# Input 5: Command line args
# Output 1: A dict of the best corpus level TP, FP and FN for the input sentence.
# Output 2: The corresponding error type dict for the above dict.
def evaluate_edits(src, hyp_dict, ref_dict, best, sent_id, args):
    # Store the best sentence level scores and hyp+ref combination IDs
    # best_f is initialised as -1 because 0 is a valid result.
    best_tp, best_fp, best_fn, best_f, best_hyp, best_ref = 0, 0, 0, -1, 0, 0
    best_cat = {}
    # Skip sentences that are not annotatable.
    if len(ref_dict.keys()) == 1:
        ref_id = list(ref_dict.keys())[0]
        if len(ref_dict[ref_id].keys()) == 1:
            cat = list(ref_dict[ref_id].values())[0][0]
            if cat == "NA":
                best_dict = {"tp": best_tp, "fp": best_fp, "fn": best_fn}
                return best_dict, best_cat
    # Compare each hyp and ref combination
    for hyp_id in hyp_dict.keys():
        for ref_id in ref_dict.keys():
            # Get the local counts for the current combination.
            tp, fp, fn, cat_dict = compareEdits(hyp_dict[hyp_id], ref_dict[ref_id])
            # Compute the local sentence scores (for verbose output only)
            loc_p, loc_r, loc_f = computeFScore(tp, fp, fn, args.beta)
            # Compute the global sentence scores
            p, r, f = computeFScore(
                tp+best["tp"], fp+best["fp"], fn+best["fn"], args.beta)
            # Save the scores if they are better in terms of:
            # 1. Higher F-score
            # 2. Same F-score, higher TP
            # 3. Same F-score and TP, lower FP
            # 4. Same F-score, TP and FP, lower FN
            if (f > best_f) or \
                    (f == best_f and tp > best_tp) or \
                    (f == best_f and tp == best_tp and fp < best_fp) or \
                    (f == best_f and tp == best_tp and fp == best_fp and fn < best_fn):
                best_tp, best_fp, best_fn = tp, fp, fn
                best_f, best_hyp, best_ref = f, hyp_id, ref_id
                best_cat = cat_dict
            # Verbose output
            if args.verbose:
                # Prepare verbose output edits.
                hyp_verb = list(sorted(hyp_dict[hyp_id].keys()))
                ref_verb = list(sorted(ref_dict[ref_id].keys()))
                # Ignore noop edits
                if not hyp_verb or hyp_verb[0][0] == -1: hyp_verb = []
                if not ref_verb or ref_verb[0][0] == -1: ref_verb = []
                # Print verbose info
                print('{:-^40}'.format(""))
                print("SENTENCE "+str(sent_id)+src[1:])
                print('{:-^40}'.format(""))
                print("SENTENCE "+str(sent_id)+" - HYP "+str(hyp_id)+" - REF "+str(ref_id))
                print("HYPOTHESIS EDITS :", hyp_verb)
                print("REFERENCE EDITS :", ref_verb)
                print("Local TP/FP/FN :", str(tp), str(fp), str(fn))
                print("Local P/R/F"+str(args.beta)+" :", str(loc_p), str(loc_r), str(loc_f))
                print("Global TP/FP/FN :", str(tp+best["tp"]), str(fp+best["fp"]), str(fn+best["fn"]))
                print("Global P/R/F"+str(args.beta)+" :", str(p), str(r), str(f))
    # Verbose output: display the best hyp+ref combination
    if args.verbose:
        print('{:-^40}'.format(""))
        print("^^ HYP "+str(best_hyp)+", REF "+str(best_ref)+" chosen for sentence "+str(sent_id))
    # Save the best TP, FP and FNs as a dict, and return this and the best_cat dict
    best_dict = {"tp": best_tp, "fp": best_fp, "fn": best_fn}
    return best_dict, best_cat

# Input 1: A dictionary of hypothesis edits for a single system.
# Input 2: A dictionary of reference edits for a single annotator.
# Output 1-3: The TP, FP and FN for the hyp vs the given ref annotator.
# Output 4: A dictionary of the error type counts.
def compareEdits(hyp_edits, ref_edits):
    tp = 0  # True Positives
    fp = 0  # False Positives
    fn = 0  # False Negatives
    cat_dict = {}  # {cat: [tp, fp, fn], ...}
    for h_edit, h_cats in hyp_edits.items():
        # noop hyp edits cannot be TP or FP
        if h_cats[0] == "noop": continue
        # TRUE POSITIVES
        if h_edit in ref_edits.keys():
            # On occasion, multiple tokens at same span.
            for h_cat in ref_edits[h_edit]:  # Use ref dict for TP
                tp += 1
                # Each dict value [TP, FP, FN]
                if h_cat in cat_dict.keys():
                    cat_dict[h_cat][0] += 1
                else:
                    cat_dict[h_cat] = [1, 0, 0]
        # FALSE POSITIVES
        else:
            # On occasion, multiple tokens at same span.
            for h_cat in h_cats:
                fp += 1
                # Each dict value [TP, FP, FN]
                if h_cat in cat_dict.keys():
                    cat_dict[h_cat][1] += 1
                else:
                    cat_dict[h_cat] = [0, 1, 0]
    for r_edit, r_cats in ref_edits.items():
        # noop ref edits cannot be FN
        if r_cats[0] == "noop": continue
        # FALSE NEGATIVES
        if r_edit not in hyp_edits.keys():
            # On occasion, multiple tokens at same span.
            for r_cat in r_cats:
                fn += 1
                # Each dict value [TP, FP, FN]
                if r_cat in cat_dict.keys():
                    cat_dict[r_cat][2] += 1
                else:
                    cat_dict[r_cat] = [0, 0, 1]
    return tp, fp, fn, cat_dict

# Input 1-3: True positives, false positives, false negatives
# Input 4: Value of beta in F-score.
# Output 1-3: Precision, Recall and F-score rounded to 4dp.
def computeFScore(tp, fp, fn, beta):
    p = float(tp)/(tp+fp) if fp else 1.0
    r = float(tp)/(tp+fn) if fn else 1.0
    f = float((1+(beta**2))*p*r)/(((beta**2)*p)+r) if p+r else 0.0
    return round(p, 4), round(r, 4), round(f, 4)

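# Worked example: with tp=7, fp=3, fn=5 and beta=0.5,
# p = 7/10 = 0.7, r = 7/12 ≈ 0.5833 and
# F0.5 = 1.25*p*r / (0.25*p + r) ≈ 0.6731,
# so computeFScore(7, 3, 5, 0.5) returns (0.7, 0.5833, 0.6731).
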
# Input 1-2: Two error category dicts. Key is cat, value is list of TP, FP, FN.
# Output: The dictionaries combined with cumulative TP, FP, FN.
def merge_dict(dict1, dict2):
    for cat, stats in dict2.items():
        if cat in dict1.keys():
            dict1[cat] = [x+y for x, y in zip(dict1[cat], stats)]
        else:
            dict1[cat] = stats
    return dict1

# Input 1: A dict; key is error cat, value is counts for [tp, fp, fn]
# Input 2: Integer value denoting level of error category granularity.
# 1: Operation tier; e.g. M, R, U. 2: Main tier; e.g. NOUN, VERB 3: Everything.
# Output: A dictionary of category TP, FP and FN based on Input 2.
def processCategories(cat_dict, setting):
    # Otherwise, do some processing.
    proc_cat_dict = {}
    for cat, cnt in cat_dict.items():
        if cat == "UNK":
            proc_cat_dict[cat] = cnt
            continue
        # M, U, R or UNK combined only.
        if setting == 1:
            if cat[0] in proc_cat_dict.keys():
                proc_cat_dict[cat[0]] = [x+y for x, y in zip(proc_cat_dict[cat[0]], cnt)]
            else:
                proc_cat_dict[cat[0]] = cnt
        # Everything without M, U or R.
        elif setting == 2:
            if cat[2:] in proc_cat_dict.keys():
                proc_cat_dict[cat[2:]] = [x+y for x, y in zip(proc_cat_dict[cat[2:]], cnt)]
            else:
                proc_cat_dict[cat[2:]] = cnt
        # All error category combinations
        else:
            return cat_dict
    return proc_cat_dict

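# Sketch: for cat_dict = {"R:NOUN": [1, 2, 3], "M:NOUN": [1, 0, 0]},
# setting 1 groups by operation tier -> {"R": [1, 2, 3], "M": [1, 0, 0]},
# setting 2 groups by main tier      -> {"NOUN": [2, 2, 3]},
# and setting 3 returns cat_dict unchanged.
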
# Input 1: A dict of global best TP, FP and FNs
# Input 2: A dict of error types and counts for those TP, FP and FNs
# Input 3: Command line args
def print_results(best, best_cats, args):
    # Prepare output title.
    if args.dt: title = " Token-Based Detection "
    elif args.ds: title = " Span-Based Detection "
    elif args.cse: title = " Span-Based Correction + Classification "
    else: title = " Span-Based Correction "
    # Category Scores
    if args.cat:
        best_cats = processCategories(best_cats, args.cat)
        print("")
        print('{:=^66}'.format(title))
        print("Category".ljust(14), "TP".ljust(8), "FP".ljust(8), "FN".ljust(8),
              "P".ljust(8), "R".ljust(8), "F"+str(args.beta))
        for cat, cnts in sorted(best_cats.items()):
            cat_p, cat_r, cat_f = computeFScore(cnts[0], cnts[1], cnts[2], args.beta)
            print(cat.ljust(14), str(cnts[0]).ljust(8), str(cnts[1]).ljust(8),
                  str(cnts[2]).ljust(8), str(cat_p).ljust(8), str(cat_r).ljust(8), cat_f)
    # Print the overall results.
    print("")
    print('{:=^46}'.format(title))
    print("\t".join(["TP", "FP", "FN", "Prec", "Rec", "F"+str(args.beta)]))
    print("\t".join(map(str, [best["tp"], best["fp"],
                              best["fn"]]+list(computeFScore(best["tp"], best["fp"], best["fn"], args.beta)))))
    print('{:=^46}'.format(""))
    print("")


if __name__ == "__main__":
    # Run the program
    main()
from rouge_chinese import Rouge
import jieba
from nltk.translate.gleu_score import corpus_gleu


def compute_f1_two_sets(pred_set, gt_set):
    precision = len(pred_set.intersection(gt_set)) / len(pred_set) if len(pred_set) > 0 else 0
    recall = len(pred_set.intersection(gt_set)) / len(gt_set) if len(gt_set) > 0 else 0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0
    return f1

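# Example: compute_f1_two_sets({"a", "b", "c"}, {"b", "c", "d"})
# gives precision = recall = 2/3 and therefore F1 = 2/3 ≈ 0.667.
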
def multi_choice_judge(prediction, option_list, answer_token):
    # a dict, key: letters in the option list, value: count of the letter in the prediction
    count_dict, abstention, accuracy = {}, 0, 0
    for option in option_list:
        option_count = prediction.count(option)
        count_dict[option] = 1 if option_count > 0 else 0  # multiple occurrences of the same letter are counted as 1
    if sum(count_dict.values()) == 0:
        abstention = 1
    # if the answer token is the only predicted token, the prediction is correct
    elif count_dict[answer_token] == 1 and sum(count_dict.values()) == 1:
        accuracy = 1
    return {"score": accuracy, "abstention": abstention}

""" """
compute the rouge score. compute the rouge score.
hyps and refs are lists of hyposisis and reference strings hyps and refs are lists of hyposisis and reference strings
empty predictions are replaces with 无内容 empty predictions are replaces with 无内容
""" """
def compute_rouge(hyps, refs): def compute_rouge(hyps, refs):
assert(len(hyps) == len(refs)) assert(len(hyps) == len(refs))
hyps = [' '.join(jieba.cut(h)) for h in hyps] hyps = [' '.join(jieba.cut(h)) for h in hyps]
hyps = [h if h.strip() != "" else "无内容" for h in hyps] hyps = [h if h.strip() != "" else "无内容" for h in hyps]
refs = [' '.join(jieba.cut(r)) for r in refs] refs = [' '.join(jieba.cut(r)) for r in refs]
return Rouge().get_scores(hyps, refs) return Rouge().get_scores(hyps, refs)
""" """
compute the gleu score. compute the gleu score.
hyps and refs are lists of hyposisis and reference strings hyps and refs are lists of hyposisis and reference strings
empty predictions are replaces with 无内容 empty predictions are replaces with 无内容
""" """
def compute_gleu(hyps, refs): def compute_gleu(hyps, refs):
assert(len(hyps) == len(refs)) assert(len(hyps) == len(refs))
hyps = [' '.join(jieba.cut(h)) for h in hyps] hyps = [' '.join(jieba.cut(h)) for h in hyps]
hyps = [h if h.strip() != "" else "无内容" for h in hyps] hyps = [h if h.strip() != "" else "无内容" for h in hyps]
refs = [[' '.join(jieba.cut(r))] for r in refs] refs = [[' '.join(jieba.cut(r))] for r in refs]
return corpus_gleu(refs, hyps) return corpus_gleu(refs, hyps)
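# Usage sketch (hypothetical strings): compute_rouge(["预测文本"], ["参考文本"]) returns
# rouge_chinese's per-pair score dicts, while compute_gleu(["预测文本"], ["参考文本"])
# returns a single corpus-level GLEU float from nltk; both segment their inputs with jieba first.
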
import numpy as np
from typing import List, Tuple, Dict
from modules.tokenizer import Tokenizer
import os
from string import punctuation

REAL_PATH = os.path.split(os.path.realpath(__file__))[0]
chinese_punct = "!?。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘'‛“”„‟…‧﹏"
english_punct = punctuation
punct = chinese_punct + english_punct


def check_all_chinese(word):
    """
    Check whether a word consists entirely of Chinese characters.
    :param word:
    :return:
    """
    return all(['\u4e00' <= ch <= '\u9fff' for ch in word])

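# Example: check_all_chinese("拼写") is True, while check_all_chinese("拼写a") and
# check_all_chinese("拼写。") are False, since "a" and "。" fall outside \u4e00-\u9fff.
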
def read_cilin():
    """
    Cilin 詞林 is a thesaurus with semantic information
    """
    # TODO -- fix this path
    project_dir = os.path.dirname(os.path.dirname(__file__))  # ymliu@2023.5.30 fix the path
    lines = open(os.path.join(project_dir, "data", "cilin.txt"), "r", encoding="gbk").read().strip().split("\n")
    semantic_dict = {}
    semantic_classes = {}
    for line in lines:
        code, *words = line.split(" ")
        for word in words:
            semantic_dict[word] = code
        # make reverse dict
        if code in semantic_classes:
            semantic_classes[code] += words
        else:
            semantic_classes[code] = words
    return semantic_dict, semantic_classes

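# Sketch of the expected cilin.txt line format (hypothetical entry): a line such as
# "Da15B02 苹果 梨" maps both 苹果 and 梨 to the code "Da15B02" in semantic_dict and
# lists them under semantic_classes["Da15B02"].
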
def read_confusion():
    confusion_dict = {}
    project_dir = os.path.dirname(os.path.dirname(__file__))  # ymliu@2023.5.30 fix the path
    with open(os.path.join(project_dir, "data", "confusion_dict.txt"), "r", encoding="utf-8") as f:
        for line in f:
            li = line.rstrip('\n').split(" ")
            confusion_dict[li[0]] = li[1:]
    return confusion_dict

class Alignment:
    """
    Align the erroneous sentence with the corrected sentence and
    extract edit operations with an edit-distance algorithm.
    """

    def __init__(
            self,
            semantic_dict: Dict,
            confusion_dict: Dict,
            granularity: str = "word",
    ) -> None:
        """
        Constructor.
        :param semantic_dict: semantic thesaurus (Cilin)
        :param confusion_dict: character confusion set
        """
        self.insertion_cost = 1
        self.deletion_cost = 1
        self.semantic_dict = semantic_dict
        self.confusion_dict = confusion_dict
        # Because we use character level tokenization, this doesn't currently use POS
        self._open_pos = {}  # at word level, POS agreement could also be used for the cost
        self.granularity = granularity  # word-level or character-level
        self.align_seqs = []

    def __call__(self,
                 src: List[Tuple],
                 tgt: List[Tuple],
                 verbose: bool = False):
        cost_matrix, oper_matrix = self.align(src, tgt)
        align_seq = self.get_cheapest_align_seq(oper_matrix)
        if verbose:
            print("========== Seg. and POS: ==========")
            print(src)
            print(tgt)
            print("========== Cost Matrix ==========")
            print(cost_matrix)
            print("========== Oper Matrix ==========")
            print(oper_matrix)
            print("========== Alignment ==========")
            print(align_seq)
            print("========== Results ==========")
            for a in align_seq:
                print(a[0], src[a[1]: a[2]], tgt[a[3]: a[4]])
        return align_seq

    def _get_semantic_class(self, word):
        """
        NOTE: Based on the paper:
        Improved-Edit-Distance Kernel for Chinese Relation Extraction
        Get the semantic class of a word (three levels, based on Cilin).
        """
        if word in self.semantic_dict:
            code = self.semantic_dict[word]
            high, mid, low = code[0], code[1], code[2:4]
            return high, mid, low
        else:  # unknown
            return None

    @staticmethod
    def _get_class_diff(a_class, b_class):
        """
        d == 3 for equivalent semantics
        d == 0 for completely different semantics
        Compute the distance between the Cilin semantic classes of two words.
        """
        d = sum([a == b for a, b in zip(a_class, b_class)])
        return d

    def _get_semantic_cost(self, a, b):
        """
        Compute the substitution cost based on semantic information.
        :param a: word a
        :param b: word b
        :return: substitution cost
        """
        a_class = self._get_semantic_class(a)
        b_class = self._get_semantic_class(b)
        # unknown class, default to 4
        if a_class is None or b_class is None:
            return 4
        elif a_class == b_class:
            return 0
        else:
            return 2 * (3 - self._get_class_diff(a_class, b_class))

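    # Sketch: if either word is missing from the thesaurus the cost is 4; two words
    # sharing the exact Cilin code cost 0; otherwise the cost is 2 * (3 - d), e.g. a
    # pair agreeing only on the highest level (d == 1) costs 4.
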
    def _get_pos_cost(self, a_pos, b_pos):
        """
        Compute the edit cost based on POS information.
        :param a_pos: POS tag of word a
        :param b_pos: POS tag of word b
        :return: substitution cost
        """
        if a_pos == b_pos:
            return 0
        elif a_pos in self._open_pos and b_pos in self._open_pos:
            return 0.25
        else:
            return 0.499

    def _get_char_cost(self, a, b, pinyin_a, pinyin_b):
        """
        NOTE: This is a replacement of ERRANT's lemma cost for Chinese.
        Compute the edit cost based on character similarity.
        """
        if not (check_all_chinese(a) and check_all_chinese(b)):
            return 0.5
        if len(a) > len(b):
            a, b = b, a
            pinyin_a, pinyin_b = pinyin_b, pinyin_a
        if a == b:
            return 0
        else:
            return self._get_spell_cost(a, b, pinyin_a, pinyin_b)

    def _get_spell_cost(self, a, b, pinyin_a, pinyin_b):
        """
        Compute the spelling similarity of two words, combining
        character-shape similarity and pronunciation similarity.
        :param a: word a
        :param b: word b, where len(a) <= len(b)
        :param pinyin_a: pinyin of word a
        :param pinyin_b: pinyin of word b
        :return: substitution cost
        """
        count = 0
        for i in range(len(a)):
            for j in range(len(b)):
                if a[i] == b[j] or (set(pinyin_a) & set(pinyin_b)) or (b[j] in self.confusion_dict.keys() and a[i] in self.confusion_dict[b[j]]) or (a[i] in self.confusion_dict.keys() and b[j] in self.confusion_dict[a[i]]):
                    count += 1
                    break
        return (len(a) - count) / (len(a) * 2)

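    # Sketch: for a = "天" and b = "夭", assuming disjoint pinyin sets and no entry for
    # either character in the confusion set, no position matches and the cost is
    # (1 - 0) / (1 * 2) = 0.5; any shared pinyin between the two words drives the cost
    # to 0, because every position then counts as matched.
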
    def get_sub_cost(self, a_seg, b_seg):
        """
        Calculate the substitution cost between words a and b.
        The maximum cost is 2, which equals one deletion plus one insertion.
        """
        if a_seg[0] == b_seg[0]:
            return 0

        if self.granularity == "word":  # at word level, POS information can also be used
            semantic_cost = self._get_semantic_cost(a_seg[0], b_seg[0]) / 6.0
            pos_cost = self._get_pos_cost(a_seg[1], b_seg[1])
            char_cost = self._get_char_cost(a_seg[0], b_seg[0], a_seg[2], b_seg[2])
            return semantic_cost + pos_cost + char_cost
        else:  # at character level, only character semantics (from Cilin) and surface similarity are available
            semantic_cost = self._get_semantic_cost(a_seg[0], b_seg[0]) / 6.0
            if a_seg[0] in punct and b_seg[0] in punct:
                pos_cost = 0.0
            elif a_seg[0] not in punct and b_seg[0] not in punct:
                pos_cost = 0.25
            else:
                pos_cost = 0.499
            # pos_cost = 0.0 if (a_seg[0] in punct and b_seg[0] in punct) or (a_seg[0] not in punct and b_seg[0] not in punct) else 0.5
            char_cost = self._get_char_cost(a_seg[0], b_seg[0], a_seg[2], b_seg[2])
            return semantic_cost + char_cost + pos_cost

    def align(self,
              src: List[Tuple],
              tgt: List[Tuple]):
        """
        Based on ERRANT's alignment.
        An improved dynamic-programming algorithm labels every character of the
        source sentence with an edit operation so that it can be transformed into
        the target sentence. Edit operation types:
        1) M: Match (KEEP), the current character stays unchanged
        2) D: Delete, the current character should be deleted
        3) I: Insert, the current character should be inserted
        4) T: Transposition, a word-order change
        """
        cost_matrix = np.zeros((len(src) + 1, len(tgt) + 1))  # edit cost matrix
        oper_matrix = np.full(
            (len(src) + 1, len(tgt) + 1), "O", dtype=object
        )  # operation matrix
        # Fill in the edges
        for i in range(1, len(src) + 1):
            cost_matrix[i][0] = cost_matrix[i - 1][0] + 1
            oper_matrix[i][0] = ["D"]
        for j in range(1, len(tgt) + 1):
            cost_matrix[0][j] = cost_matrix[0][j - 1] + 1
            oper_matrix[0][j] = ["I"]
        # Loop through the cost matrix
        for i in range(len(src)):
            for j in range(len(tgt)):
                # Matches
                if src[i][0] == tgt[j][0]:  # equal characters are a Match with cost 0
                    cost_matrix[i + 1][j + 1] = cost_matrix[i][j]
                    oper_matrix[i + 1][j + 1] = ["M"]
                # Non-matches
                else:
                    del_cost = cost_matrix[i][j + 1] + self.deletion_cost  # total cost of a deletion
                    ins_cost = cost_matrix[i + 1][j] + self.insertion_cost  # total cost of an insertion
                    sub_cost = cost_matrix[i][j] + self.get_sub_cost(
                        src[i], tgt[j]
                    )  # total cost of a substitution
                    # Calculate transposition cost
                    trans_cost = float("inf")
                    k = 1
                    while (
                        i - k >= 0
                        and j - k >= 0
                        and cost_matrix[i - k + 1][j - k + 1]
                        != cost_matrix[i - k][j - k]
                    ):
                        p1 = sorted([a[0] for a in src][i - k: i + 1])
                        p2 = sorted([b[0] for b in tgt][j - k: j + 1])
                        if p1 == p2:
                            trans_cost = cost_matrix[i - k][j - k] + k
                            break
                        k += 1
                    costs = [trans_cost, sub_cost, ins_cost, del_cost]
                    ind = costs.index(min(costs))
                    cost_matrix[i + 1][j + 1] = costs[ind]
                    # ind = costs.index(costs[ind], ind+1)
                    for idx, cost in enumerate(costs):
                        if cost == costs[ind]:
                            if idx == 0:
                                if oper_matrix[i + 1][j + 1] == "O":
                                    oper_matrix[i + 1][j + 1] = ["T" + str(k + 1)]
                                else:
                                    oper_matrix[i + 1][j + 1].append("T" + str(k + 1))
                            elif idx == 1:
                                if oper_matrix[i + 1][j + 1] == "O":
                                    oper_matrix[i + 1][j + 1] = ["S"]
                                else:
                                    oper_matrix[i + 1][j + 1].append("S")
                            elif idx == 2:
                                if oper_matrix[i + 1][j + 1] == "O":
                                    oper_matrix[i + 1][j + 1] = ["I"]
                                else:
                                    oper_matrix[i + 1][j + 1].append("I")
                            else:
                                if oper_matrix[i + 1][j + 1] == "O":
                                    oper_matrix[i + 1][j + 1] = ["D"]
                                else:
                                    oper_matrix[i + 1][j + 1].append("D")
        return cost_matrix, oper_matrix

def _dfs(self, i, j, align_seq_now, oper_matrix, strategy="all"): def _dfs(self, i, j, align_seq_now, oper_matrix, strategy="all"):
""" """
深度优先遍历,获取最小编辑距离相同的所有序列 深度优先遍历,获取最小编辑距离相同的所有序列
""" """
if i + j == 0: if i + j == 0:
self.align_seqs.append(align_seq_now) self.align_seqs.append(align_seq_now)
else: else:
ops = oper_matrix[i][j] # 可以类比成搜索一棵树从根结点到叶子结点的所有路径 ops = oper_matrix[i][j] # 可以类比成搜索一棵树从根结点到叶子结点的所有路径
if strategy != "all": ops = ops[:1] if strategy != "all": ops = ops[:1]
for op in ops: for op in ops:
if op in {"M", "S"}: if op in {"M", "S"}:
self._dfs(i - 1, j - 1, align_seq_now + [(op, i - 1, i, j - 1, j)], oper_matrix, strategy) self._dfs(i - 1, j - 1, align_seq_now + [(op, i - 1, i, j - 1, j)], oper_matrix, strategy)
elif op == "D": elif op == "D":
self._dfs(i - 1, j, align_seq_now + [(op, i - 1, i, j, j)], oper_matrix, strategy) self._dfs(i - 1, j, align_seq_now + [(op, i - 1, i, j, j)], oper_matrix, strategy)
elif op == "I": elif op == "I":
self._dfs(i, j - 1, align_seq_now + [(op, i, i, j - 1, j)], oper_matrix, strategy) self._dfs(i, j - 1, align_seq_now + [(op, i, i, j - 1, j)], oper_matrix, strategy)
else: else:
k = int(op[1:]) k = int(op[1:])
self._dfs(i - k, j - k, align_seq_now + [(op, i - k, i, j - k, j)], oper_matrix, strategy) self._dfs(i - k, j - k, align_seq_now + [(op, i - k, i, j - k, j)], oper_matrix, strategy)
def get_cheapest_align_seq(self, oper_matrix): def get_cheapest_align_seq(self, oper_matrix):
""" """
回溯获得编辑距离最小的编辑序列 回溯获得编辑距离最小的编辑序列
""" """
self.align_seqs = [] self.align_seqs = []
i = oper_matrix.shape[0] - 1 i = oper_matrix.shape[0] - 1
j = oper_matrix.shape[1] - 1 j = oper_matrix.shape[1] - 1
if abs(i - j) > 10: if abs(i - j) > 10:
self._dfs(i, j , [], oper_matrix, "first") self._dfs(i, j , [], oper_matrix, "first")
else: else:
self._dfs(i, j , [], oper_matrix, "all") self._dfs(i, j , [], oper_matrix, "all")
final_align_seqs = [seq[::-1] for seq in self.align_seqs] final_align_seqs = [seq[::-1] for seq in self.align_seqs]
return final_align_seqs return final_align_seqs
if __name__ == "__main__": if __name__ == "__main__":
tokenizer = Tokenizer("word") tokenizer = Tokenizer("word")
semantic_dict, semantic_class = read_cilin() semantic_dict, semantic_class = read_cilin()
confusion_dict = read_confusion() confusion_dict = read_confusion()
alignment = Alignment(semantic_dict, confusion_dict) alignment = Alignment(semantic_dict, confusion_dict)
sents = ["首先 , 我们 得 准备 : 大 虾六 到 九 只 、 盐 一 茶匙 、 已 搾 好 的 柠檬汁 三 汤匙 、 泰国 柠檬 叶三叶 、 柠檬 香草 一 根 、 鱼酱 两 汤匙 、 辣椒 6 粒 , 纯净 水 4量杯 、 香菜 半量杯 和 草菇 10 个 。".replace(" ", ""), "首先 , 我们 得 准备 : 大 虾六 到 九 只 、 盐 一 茶匙 、 已 榨 好 的 柠檬汁 三 汤匙 、 泰国 柠檬 叶三叶 、 柠檬 香草 一 根 、 鱼酱 两 汤匙 、 辣椒 六 粒 , 纯净 水 四 量杯 、 香菜 半量杯 和 草菇 十 个 。".replace(" ", "")] sents = ["首先 , 我们 得 准备 : 大 虾六 到 九 只 、 盐 一 茶匙 、 已 搾 好 的 柠檬汁 三 汤匙 、 泰国 柠檬 叶三叶 、 柠檬 香草 一 根 、 鱼酱 两 汤匙 、 辣椒 6 粒 , 纯净 水 4量杯 、 香菜 半量杯 和 草菇 10 个 。".replace(" ", ""), "首先 , 我们 得 准备 : 大 虾六 到 九 只 、 盐 一 茶匙 、 已 榨 好 的 柠檬汁 三 汤匙 、 泰国 柠檬 叶三叶 、 柠檬 香草 一 根 、 鱼酱 两 汤匙 、 辣椒 六 粒 , 纯净 水 四 量杯 、 香菜 半量杯 和 草菇 十 个 。".replace(" ", "")]
src, tgt = tokenizer(sents) src, tgt = tokenizer(sents)
alignment(src, tgt, verbose=True) alignment(src, tgt, verbose=True)
\ No newline at end of file
from typing import List, Tuple
from modules.alignment import read_cilin, read_confusion, Alignment
from modules.merger import Merger
from modules.classifier import Classifier


class Annotator:
    def __init__(self,
                 align: Alignment,
                 merger: Merger,
                 classifier: Classifier,
                 granularity: str = "word",
                 strategy: str = "first"):
        self.align = align
        self.merger = merger
        self.classifier = classifier
        self.granularity = granularity
        self.strategy = strategy

    @classmethod
    def create_default(cls, granularity: str = "word", strategy: str = "first"):
        """
        Default parameters used in the paper
        """
        semantic_dict, semantic_class = read_cilin()
        confusion_dict = read_confusion()
        align = Alignment(semantic_dict, confusion_dict, granularity)
        merger = Merger(granularity)
        classifier = Classifier(granularity)
        return cls(align, merger, classifier, granularity, strategy)

    def __call__(self,
                 src: List[Tuple],
                 tgt: List[Tuple],
                 annotator_id: int = 0,
                 verbose: bool = False):
        """
        Align sentences and annotate them with error type information
        """
        src_tokens = [x[0] for x in src]
        tgt_tokens = [x[0] for x in tgt]
        src_str = "".join(src_tokens)
        tgt_str = "".join(tgt_tokens)
        # convert to text form
        annotations_out = ["S " + " ".join(src_tokens) + "\n"]
        if tgt_str == "没有错误" or src_str == tgt_str:  # Error Free Case
            annotations_out.append(f"T{annotator_id} 没有错误\n")
            cors = [tgt_str]
            op, toks, inds = "noop", "-NONE-", (-1, -1)
            a_str = f"A {inds[0]} {inds[1]}|||{op}|||{toks}|||REQUIRED|||-NONE-|||{annotator_id}\n"
            annotations_out.append(a_str)
        elif tgt_str == "无法标注":  # Not Annotatable Case
            annotations_out.append(f"T{annotator_id} 无法标注\n")
            cors = [tgt_str]
            op, toks, inds = "NA", "-NONE-", (-1, -1)
            a_str = f"A {inds[0]} {inds[1]}|||{op}|||{toks}|||REQUIRED|||-NONE-|||{annotator_id}\n"
            annotations_out.append(a_str)
        else:  # Other
            align_objs = self.align(src, tgt)
            edit_objs = []
            align_idx = 0
            if self.strategy == "first":
                align_objs = align_objs[:1]
            for align_obj in align_objs:
                edits = self.merger(align_obj, src, tgt, verbose)
                if edits not in edit_objs:
                    edit_objs.append(edits)
                    annotations_out.append(f"T{annotator_id}-A{align_idx} " + " ".join(tgt_tokens) + "\n")
                    align_idx += 1
                    cors = self.classifier(src, tgt, edits, verbose)
                    # annotations_out = []
                    for cor in cors:
                        op, toks, inds = cor.op, cor.toks, cor.inds
                        a_str = f"A {inds[0]} {inds[1]}|||{op}|||{toks}|||REQUIRED|||-NONE-|||{annotator_id}\n"
                        annotations_out.append(a_str)
        annotations_out.append("\n")
        return annotations_out, cors

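# Sketch of an emitted M2-style block for one sentence (hypothetical tokens, with <op>
# standing in for whatever error type the classifier assigns):
#   S 我 爱 北京
#   T0-A0 我 喜欢 北京
#   A 1 2|||<op>|||喜欢|||REQUIRED|||-NONE-|||0
# The "A" line fields are span start/end, error type, correction tokens and annotator id.
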
from char_smi import CharFuncs from char_smi import CharFuncs
from collections import namedtuple from collections import namedtuple
from pypinyin import pinyin, Style from pypinyin import pinyin, Style
import os import os
Correction = namedtuple( Correction = namedtuple(
"Correction", "Correction",
[ [
"op", "op",
"toks", "toks",
"inds", "inds",
], ],
) )
file_path = os.path.dirname(os.path.abspath(__file__)) file_path = os.path.dirname(os.path.abspath(__file__))
char_smi = CharFuncs(os.path.join(file_path.replace("modules", ""), 'data/char_meta.txt')) char_smi = CharFuncs(os.path.join(file_path.replace("modules", ""), 'data/char_meta.txt'))
def check_spell_error(src_span: str, def check_spell_error(src_span: str,
tgt_span: str, tgt_span: str,
threshold: float = 0.8) -> bool: threshold: float = 0.8) -> bool:
if len(src_span) != len(tgt_span): if len(src_span) != len(tgt_span):
return False return False
src_chars = [ch for ch in src_span] src_chars = [ch for ch in src_span]
tgt_chars = [ch for ch in tgt_span] tgt_chars = [ch for ch in tgt_span]
if sorted(src_chars) == sorted(tgt_chars): # characters transposed within the word if sorted(src_chars) == sorted(tgt_chars): # characters transposed within the word
return True return True
for src_char, tgt_char in zip(src_chars, tgt_chars): for src_char, tgt_char in zip(src_chars, tgt_chars):
if src_char != tgt_char: if src_char != tgt_char:
if src_char not in char_smi.data or tgt_char not in char_smi.data: if src_char not in char_smi.data or tgt_char not in char_smi.data:
return False return False
v_sim = char_smi.shape_similarity(src_char, tgt_char) v_sim = char_smi.shape_similarity(src_char, tgt_char)
p_sim = char_smi.pronunciation_similarity(src_char, tgt_char) p_sim = char_smi.pronunciation_similarity(src_char, tgt_char)
if v_sim + p_sim < threshold and not ( if v_sim + p_sim < threshold and not (
set(pinyin(src_char, style=Style.NORMAL, heteronym=True)[0]) & set(pinyin(tgt_char, style=Style.NORMAL, heteronym=True)[0])): set(pinyin(src_char, style=Style.NORMAL, heteronym=True)[0]) & set(pinyin(tgt_char, style=Style.NORMAL, heteronym=True)[0])):
return False return False
return True return True
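# The heuristic above only fires for equal-length spans: transposed characters count as a
# spelling error, and substituted characters count only when they are visually/phonetically
# similar enough or share a pinyin reading. Illustrative behaviour:
#   check_spell_error("天今", "今天")   -> True   (characters transposed within the word)
#   check_spell_error("天气", "天气好") -> False  (spans of different length are never SPELL)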
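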
class Classifier: class Classifier:
""" """
Error type classifier Error type classifier
""" """
def __init__(self, def __init__(self,
granularity: str = "word"): granularity: str = "word"):
self.granularity = granularity self.granularity = granularity
@staticmethod @staticmethod
def get_pos_type(pos): def get_pos_type(pos):
if pos in {"n", "nd"}: if pos in {"n", "nd"}:
return "NOUN" return "NOUN"
if pos in {"nh", "ni", "nl", "ns", "nt", "nz"}: if pos in {"nh", "ni", "nl", "ns", "nt", "nz"}:
return "NOUN-NE" return "NOUN-NE"
if pos in {"v"}: if pos in {"v"}:
return "VERB" return "VERB"
if pos in {"a", "b"}: if pos in {"a", "b"}:
return "ADJ" return "ADJ"
if pos in {"c"}: if pos in {"c"}:
return "CONJ" return "CONJ"
if pos in {"r"}: if pos in {"r"}:
return "PRON" return "PRON"
if pos in {"d"}: if pos in {"d"}:
return "ADV" return "ADV"
if pos in {"u"}: if pos in {"u"}:
return "AUX" return "AUX"
# if pos in {"k"}: # TODO 后缀词比例太少,暂且分入其它 # if pos in {"k"}: # TODO 后缀词比例太少,暂且分入其它
# return "SUFFIX" # return "SUFFIX"
if pos in {"m"}: if pos in {"m"}:
return "NUM" return "NUM"
if pos in {"p"}: if pos in {"p"}:
return "PREP" return "PREP"
if pos in {"q"}: if pos in {"q"}:
return "QUAN" return "QUAN"
if pos in {"wp"}: if pos in {"wp"}:
return "PUNCT" return "PUNCT"
return "OTHER" return "OTHER"
def __call__(self, def __call__(self,
src, src,
tgt, tgt,
edits, edits,
verbose: bool = False): verbose: bool = False):
""" """
Assign an error type to each edit operation Assign an error type to each edit operation
:param src: source (erroneous) sentence information :param src: source (erroneous) sentence information
:param tgt: target (corrected) sentence information :param tgt: target (corrected) sentence information
:param edits: edit operations :param edits: edit operations
:param verbose: whether to print debug information :param verbose: whether to print debug information
:return: edit operations with error types assigned :return: edit operations with error types assigned
""" """
results = [] results = []
src_tokens = [x[0] for x in src] src_tokens = [x[0] for x in src]
tgt_tokens = [x[0] for x in tgt] tgt_tokens = [x[0] for x in tgt]
for edit in edits: for edit in edits:
error_type = edit[0] error_type = edit[0]
src_span = " ".join(src_tokens[edit[1]: edit[2]]) src_span = " ".join(src_tokens[edit[1]: edit[2]])
tgt_span = " ".join(tgt_tokens[edit[3]: edit[4]]) tgt_span = " ".join(tgt_tokens[edit[3]: edit[4]])
# print(tgt_span) # print(tgt_span)
cor = None cor = None
if error_type[0] == "T": if error_type[0] == "T":
cor = Correction("W", tgt_span, (edit[1], edit[2])) cor = Correction("W", tgt_span, (edit[1], edit[2]))
elif error_type[0] == "D": elif error_type[0] == "D":
if self.granularity == "word": # word level: error types can be fine-grained if self.granularity == "word": # word level: error types can be fine-grained
if edit[2] - edit[1] > 1: # redundant multi-word span: classified as OTHER for now if edit[2] - edit[1] > 1: # redundant multi-word span: classified as OTHER for now
cor = Correction("R:OTHER", "-NONE-", (edit[1], edit[2])) cor = Correction("R:OTHER", "-NONE-", (edit[1], edit[2]))
else: else:
pos = self.get_pos_type(src[edit[1]][1]) pos = self.get_pos_type(src[edit[1]][1])
pos = "NOUN" if pos == "NOUN-NE" else pos pos = "NOUN" if pos == "NOUN-NE" else pos
pos = "MC" if tgt_span == "[缺失成分]" else pos pos = "MC" if tgt_span == "[缺失成分]" else pos
cor = Correction("R:{:s}".format(pos), "-NONE-", (edit[1], edit[2])) cor = Correction("R:{:s}".format(pos), "-NONE-", (edit[1], edit[2]))
else: # char level: classify by operation type only else: # char level: classify by operation type only
cor = Correction("R", "-NONE-", (edit[1], edit[2])) cor = Correction("R", "-NONE-", (edit[1], edit[2]))
elif error_type[0] == "I": elif error_type[0] == "I":
if self.granularity == "word": # word level: error types can be fine-grained if self.granularity == "word": # word level: error types can be fine-grained
if edit[4] - edit[3] > 1: # missing multi-word span: classified as OTHER for now if edit[4] - edit[3] > 1: # missing multi-word span: classified as OTHER for now
cor = Correction("M:OTHER", tgt_span, (edit[1], edit[2])) cor = Correction("M:OTHER", tgt_span, (edit[1], edit[2]))
else: else:
pos = self.get_pos_type(tgt[edit[3]][1]) pos = self.get_pos_type(tgt[edit[3]][1])
pos = "NOUN" if pos == "NOUN-NE" else pos pos = "NOUN" if pos == "NOUN-NE" else pos
pos = "MC" if tgt_span == "[缺失成分]" else pos pos = "MC" if tgt_span == "[缺失成分]" else pos
cor = Correction("M:{:s}".format(pos), tgt_span, (edit[1], edit[2])) cor = Correction("M:{:s}".format(pos), tgt_span, (edit[1], edit[2]))
else: # char level: classify by operation type only else: # char level: classify by operation type only
cor = Correction("M", tgt_span, (edit[1], edit[2])) cor = Correction("M", tgt_span, (edit[1], edit[2]))
elif error_type[0] == "S": elif error_type[0] == "S":
if self.granularity == "word": # word level: error types can be fine-grained if self.granularity == "word": # word level: error types can be fine-grained
if check_spell_error(src_span.replace(" ", ""), tgt_span.replace(" ", "")): if check_spell_error(src_span.replace(" ", ""), tgt_span.replace(" ", "")):
cor = Correction("S:SPELL", tgt_span, (edit[1], edit[2])) cor = Correction("S:SPELL", tgt_span, (edit[1], edit[2]))
# TODO: named-entity spelling errors are not distinguished separately for now # TODO: named-entity spelling errors are not distinguished separately for now
# if edit[4] - edit[3] > 1: # if edit[4] - edit[3] > 1:
# cor = Correction("S:SPELL:COMMON", tgt_span, (edit[1], edit[2])) # cor = Correction("S:SPELL:COMMON", tgt_span, (edit[1], edit[2]))
# else: # else:
# pos = self.get_pos_type(tgt[edit[3]][1]) # pos = self.get_pos_type(tgt[edit[3]][1])
# if pos == "NOUN-NE": # 命名实体拼写有误 # if pos == "NOUN-NE": # 命名实体拼写有误
# cor = Correction("S:SPELL:NE", tgt_span, (edit[1], edit[2])) # cor = Correction("S:SPELL:NE", tgt_span, (edit[1], edit[2]))
# else: # common word misspelled # else: # common word misspelled
# cor = Correction("S:SPELL:COMMON", tgt_span, (edit[1], edit[2])) # cor = Correction("S:SPELL:COMMON", tgt_span, (edit[1], edit[2]))
else: else:
if edit[4] - edit[3] > 1: # replaced multi-word span: classified as OTHER for now if edit[4] - edit[3] > 1: # replaced multi-word span: classified as OTHER for now
cor = Correction("S:OTHER", tgt_span, (edit[1], edit[2])) cor = Correction("S:OTHER", tgt_span, (edit[1], edit[2]))
else: else:
pos = self.get_pos_type(tgt[edit[3]][1]) pos = self.get_pos_type(tgt[edit[3]][1])
pos = "NOUN" if pos == "NOUN-NE" else pos pos = "NOUN" if pos == "NOUN-NE" else pos
pos = "MC" if tgt_span == "[缺失成分]" else pos pos = "MC" if tgt_span == "[缺失成分]" else pos
cor = Correction("S:{:s}".format(pos), tgt_span, (edit[1], edit[2])) cor = Correction("S:{:s}".format(pos), tgt_span, (edit[1], edit[2]))
else: # char level: classify by operation type only else: # char level: classify by operation type only
cor = Correction("S", tgt_span, (edit[1], edit[2])) cor = Correction("S", tgt_span, (edit[1], edit[2]))
results.append(cor) results.append(cor)
if verbose: if verbose:
print("========== Corrections ==========") print("========== Corrections ==========")
for cor in results: for cor in results:
print("Type: {:s}, Position: {:d} -> {:d}, Target: {:s}".format(cor.op, cor.inds[0], cor.inds[1], cor.toks)) print("Type: {:s}, Position: {:d} -> {:d}, Target: {:s}".format(cor.op, cor.inds[0], cor.inds[1], cor.toks))
return results return results
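# Each returned Correction is a namedtuple (op, toks, inds); e.g. a word-level
# substitution classified as a spelling error would look roughly like (values illustrative)
#   Correction(op="S:SPELL", toks="牛肉", inds=(3, 4))
# where inds are token offsets into the source sentence.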
# print(pinyin("朝", style=Style.NORMAL)) # print(pinyin("朝", style=Style.NORMAL))
from itertools import groupby from itertools import groupby
from string import punctuation from string import punctuation
from typing import List from typing import List
from modules.tokenizer import Tokenizer from modules.tokenizer import Tokenizer
from modules.alignment import Alignment, read_cilin, read_confusion from modules.alignment import Alignment, read_cilin, read_confusion
import Levenshtein import Levenshtein
class Merger: class Merger:
""" """
Merge edit operations, converting them from token level to span level Merge edit operations, converting them from token level to span level
""" """
def __init__(self, def __init__(self,
granularity: str = "word", granularity: str = "word",
merge: bool = False): merge: bool = False):
chinese_punct = "!?。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟–—‘'‛“”„‟…‧." chinese_punct = "!?。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟–—‘'‛“”„‟…‧."
self.punctuation = punctuation + chinese_punct self.punctuation = punctuation + chinese_punct
self.not_merge_token = [punct for punct in self.punctuation] self.not_merge_token = [punct for punct in self.punctuation]
self.granularity = granularity self.granularity = granularity
self.merge = merge self.merge = merge
@staticmethod @staticmethod
def _merge_edits(seq, tag="X"): def _merge_edits(seq, tag="X"):
if seq: if seq:
return [(tag, seq[0][1], seq[-1][2], seq[0][3], seq[-1][4])] return [(tag, seq[0][1], seq[-1][2], seq[0][3], seq[-1][4])]
else: else:
return seq return seq
@staticmethod @staticmethod
def _check_revolve(span_a, span_b): def _check_revolve(span_a, span_b):
span_a = span_a + span_a span_a = span_a + span_a
return span_b in span_a return span_b in span_a
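# _check_revolve tests whether span_b is a rotation of span_a: doubling span_a makes every
# rotation appear as a substring, e.g. (illustrative)
#   Merger._check_revolve("我们", "们我") -> True
#   Merger._check_revolve("我们", "他们") -> False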
def _process_seq(self, seq, src_tokens, tgt_tokens): def _process_seq(self, seq, src_tokens, tgt_tokens):
if len(seq) <= 1: if len(seq) <= 1:
return seq return seq
ops = [op[0] for op in seq] ops = [op[0] for op in seq]
if set(ops) == {"D"} or set(ops) == {"I"}: if set(ops) == {"D"} or set(ops) == {"I"}:
return self._merge_edits(seq, set(ops).pop()) return self._merge_edits(seq, set(ops).pop())
if set(ops) == {"D", "I"} or set(ops) == {"I", "D"}: if set(ops) == {"D", "I"} or set(ops) == {"I", "D"}:
# do not merge this pattern_from_qua.txt # do not merge this pattern_from_qua.txt
return seq return seq
if set(ops) == {"S"}: if set(ops) == {"S"}:
if self.granularity == "word": if self.granularity == "word":
return seq return seq
else: else:
return self._merge_edits(seq, "S") return self._merge_edits(seq, "S")
if set(ops) == {"M"}: if set(ops) == {"M"}:
return self._merge_edits(seq, "M") return self._merge_edits(seq, "M")
return self._merge_edits(seq, "S") return self._merge_edits(seq, "S")
def __call__(self, def __call__(self,
align_obj, align_obj,
src: List, src: List,
tgt: List, tgt: List,
verbose: bool = False): verbose: bool = False):
""" """
Based on ERRANT's merge, adapted for Chinese Based on ERRANT's merge, adapted for Chinese
""" """
src_tokens = [x[0] for x in src] src_tokens = [x[0] for x in src]
tgt_tokens = [x[0] for x in tgt] tgt_tokens = [x[0] for x in tgt]
edits = [] edits = []
# Split alignment into groups of M, T and rest. (T has a number after it) # Split alignment into groups of M, T and rest. (T has a number after it)
# TODO: once an inserted/deleted/substituted span contains punctuation, do not merge it with other edits # TODO: once an inserted/deleted/substituted span contains punctuation, do not merge it with other edits
# TODO: the missing-component tag is likewise not merged with other edits # TODO: the missing-component tag is likewise not merged with other edits
for op, group in groupby( for op, group in groupby(
align_obj, align_obj,
lambda x: x[0][0] if x[0][0] in {"M", "T"} else False, lambda x: x[0][0] if x[0][0] in {"M", "T"} else False,
): ):
group = list(group) group = list(group)
# T is always split TODO: Evaluate this # T is always split TODO: Evaluate this
if op == "T": if op == "T":
for seq in group: for seq in group:
edits.append(seq) edits.append(seq)
# Process D, I and S subsequence # Process D, I and S subsequence
else: else:
# Turn the processed sequence into edits # Turn the processed sequence into edits
processed = self._process_seq(group, src_tokens, tgt_tokens) processed = self._process_seq(group, src_tokens, tgt_tokens)
for seq in processed: for seq in processed:
edits.append(seq) edits.append(seq)
filtered_edits = [] filtered_edits = []
i = 0 i = 0
while i < len(edits): while i < len(edits):
e1 = edits[i][0][0] e1 = edits[i][0][0]
if i < len(edits) - 2: if i < len(edits) - 2:
e2 = edits[i + 1][0][0] e2 = edits[i + 1][0][0]
e3 = edits[i + 2][0][0] e3 = edits[i + 2][0][0]
# Find "S M S" patterns # Find "S M S" patterns
# Ex: # Ex:
# S M S # S M S
# 冬阴功 对 外国人 # 冬阴功 对 外国人
# 外国人 对 冬阴功 # 外国人 对 冬阴功
if e1 == "S" and e2 == "M" and e3 == "S": if e1 == "S" and e2 == "M" and e3 == "S":
w1 = "".join(src_tokens[edits[i][1]: edits[i][2]]) w1 = "".join(src_tokens[edits[i][1]: edits[i][2]])
w2 = "".join(tgt_tokens[edits[i][3]: edits[i][4]]) w2 = "".join(tgt_tokens[edits[i][3]: edits[i][4]])
w3 = "".join(src_tokens[edits[i + 2][1]: edits[i + 2][2]]) w3 = "".join(src_tokens[edits[i + 2][1]: edits[i + 2][2]])
w4 = "".join(tgt_tokens[edits[i + 2][3]: edits[i + 2][4]]) w4 = "".join(tgt_tokens[edits[i + 2][3]: edits[i + 2][4]])
if min([len(w1), len(w2), len(w3), len(w4)]) == 1: if min([len(w1), len(w2), len(w3), len(w4)]) == 1:
if w1 == w4 and w2 == w3: if w1 == w4 and w2 == w3:
group = [edits[i], edits[i + 1], edits[i + 2]] group = [edits[i], edits[i + 1], edits[i + 2]]
processed = self._merge_edits(group, "T" + str(edits[i+2][2] - edits[i][1])) processed = self._merge_edits(group, "T" + str(edits[i+2][2] - edits[i][1]))
for seq in processed: for seq in processed:
filtered_edits.append(seq) filtered_edits.append(seq)
i += 3 i += 3
else: else:
filtered_edits.append(edits[i]) filtered_edits.append(edits[i])
i += 1 i += 1
else: else:
if Levenshtein.distance(w1, w4) <= 1 and Levenshtein.distance(w2, w3) <= 1: if Levenshtein.distance(w1, w4) <= 1 and Levenshtein.distance(w2, w3) <= 1:
group = [edits[i], edits[i + 1], edits[i + 2]] group = [edits[i], edits[i + 1], edits[i + 2]]
processed = self._merge_edits(group, "T" + str(edits[i + 2][2] - edits[i][1])) processed = self._merge_edits(group, "T" + str(edits[i + 2][2] - edits[i][1]))
for seq in processed: for seq in processed:
filtered_edits.append(seq) filtered_edits.append(seq)
i += 3 i += 3
else: else:
filtered_edits.append(edits[i]) filtered_edits.append(edits[i])
i += 1 i += 1
# Find "D M I" or "I M D" patterns # Find "D M I" or "I M D" patterns
# Ex: # Ex:
# D M I # D M I
# 旅游 去 陌生 的 地方 # 旅游 去 陌生 的 地方
# 去 陌生 的 地方 旅游 # 去 陌生 的 地方 旅游
elif (e1 == "D" and (e2 == "M" or e2.startswith("T")) and e3 == "I") or (e1 == "I" and (e2 == "M" or e2.startswith("T")) and e3 == "D"): elif (e1 == "D" and (e2 == "M" or e2.startswith("T")) and e3 == "I") or (e1 == "I" and (e2 == "M" or e2.startswith("T")) and e3 == "D"):
if e1 == "D": if e1 == "D":
delete_token = src_tokens[edits[i][1]: edits[i][2]] delete_token = src_tokens[edits[i][1]: edits[i][2]]
insert_token = tgt_tokens[edits[i + 2][3]: edits[i + 2][4]] insert_token = tgt_tokens[edits[i + 2][3]: edits[i + 2][4]]
else: else:
delete_token = src_tokens[edits[i + 2][1]: edits[i + 2][2]] delete_token = src_tokens[edits[i + 2][1]: edits[i + 2][2]]
insert_token = tgt_tokens[edits[i][3]: edits[i][4]] insert_token = tgt_tokens[edits[i][3]: edits[i][4]]
a, b = "".join(delete_token), "".join(insert_token) a, b = "".join(delete_token), "".join(insert_token)
if len(a) < len(b): if len(a) < len(b):
a, b = b, a a, b = b, a
if a not in self.punctuation and b not in self.punctuation and len(a) - len(b) <= 1: if a not in self.punctuation and b not in self.punctuation and len(a) - len(b) <= 1:
if len(b) == 1: if len(b) == 1:
if a == b: if a == b:
group = [edits[i], edits[i + 1], edits[i + 2]] group = [edits[i], edits[i + 1], edits[i + 2]]
processed = self._merge_edits(group, "T" + str(edits[i+2][2] - edits[i][1])) processed = self._merge_edits(group, "T" + str(edits[i+2][2] - edits[i][1]))
for seq in processed: for seq in processed:
filtered_edits.append(seq) filtered_edits.append(seq)
i += 3 i += 3
else: else:
filtered_edits.append(edits[i]) filtered_edits.append(edits[i])
i += 1 i += 1
else: else:
if Levenshtein.distance(a, b) <= 1 or (len(a) == len(b) and self._check_revolve(a, b)): if Levenshtein.distance(a, b) <= 1 or (len(a) == len(b) and self._check_revolve(a, b)):
group = [edits[i], edits[i + 1], edits[i + 2]] group = [edits[i], edits[i + 1], edits[i + 2]]
processed = self._merge_edits(group, "T" + str(edits[i + 2][2] - edits[i][1])) processed = self._merge_edits(group, "T" + str(edits[i + 2][2] - edits[i][1]))
for seq in processed: for seq in processed:
filtered_edits.append(seq) filtered_edits.append(seq)
i += 3 i += 3
else: else:
filtered_edits.append(edits[i]) filtered_edits.append(edits[i])
i += 1 i += 1
else: else:
filtered_edits.append(edits[i]) filtered_edits.append(edits[i])
i += 1 i += 1
else: else:
if e1 != "M": if e1 != "M":
filtered_edits.append(edits[i]) filtered_edits.append(edits[i])
i += 1 i += 1
else: else:
if e1 != "M": if e1 != "M":
filtered_edits.append(edits[i]) filtered_edits.append(edits[i])
i += 1 i += 1
# In rare cases with word-level tokenization, the following error can occur: # In rare cases with word-level tokenization, the following error can occur:
# M D S M # M D S M
# 有 時 住 上層 # 有 時 住 上層
# 有 時住 上層 # 有 時住 上層
# Which results in S: 時住 --> 時住 # Which results in S: 時住 --> 時住
# We need to filter this case out # We need to filter this case out
second_filter = [] second_filter = []
for edit in filtered_edits: # avoid mismatches caused by word-segmentation errors for edit in filtered_edits: # avoid mismatches caused by word-segmentation errors
span1 = "".join(src_tokens[edit[1] : edit[2]]) span1 = "".join(src_tokens[edit[1] : edit[2]])
span2 = "".join(tgt_tokens[edit[3] : edit[4]]) span2 = "".join(tgt_tokens[edit[3] : edit[4]])
if span1 != span2: if span1 != span2:
if edit[0] == "S": if edit[0] == "S":
b = True b = True
# In rare cases with word-level tokenization, the following error can occur: # In rare cases with word-level tokenization, the following error can occur:
# S I I M # S I I M
# 负责任 老师 # 负责任 老师
# 负 责任 的 老师 # 负 责任 的 老师
# Which results in S: 负责任 --> 负 责任 的 # Which results in S: 负责任 --> 负 责任 的
# We need to convert this edit to I: --> 的 # We need to convert this edit to I: --> 的
# overlap at the start of the span # overlap at the start of the span
common_str = "" common_str = ""
tmp_new_start_1 = edit[1] tmp_new_start_1 = edit[1]
for i in range(edit[1], edit[2]): for i in range(edit[1], edit[2]):
if not span2.startswith(common_str + src_tokens[i]): if not span2.startswith(common_str + src_tokens[i]):
break break
common_str += src_tokens[i] common_str += src_tokens[i]
tmp_new_start_1 = i + 1 tmp_new_start_1 = i + 1
new_start_1, new_start_2 = edit[1], edit[3] new_start_1, new_start_2 = edit[1], edit[3]
if common_str: if common_str:
tmp_str = "" tmp_str = ""
for i in range(edit[3], edit[4]): for i in range(edit[3], edit[4]):
tmp_str += tgt_tokens[i] tmp_str += tgt_tokens[i]
if tmp_str == common_str: if tmp_str == common_str:
new_start_1, new_start_2 = tmp_new_start_1, i + 1 new_start_1, new_start_2 = tmp_new_start_1, i + 1
# second_filter.append(("S", new_start_1, edit[2], i + 1, edit[4])) # second_filter.append(("S", new_start_1, edit[2], i + 1, edit[4]))
b = False b = False
break break
elif len(tmp_str) > len(common_str): elif len(tmp_str) > len(common_str):
break break
# overlap at the end of the span # overlap at the end of the span
common_str = "" common_str = ""
new_end_1, new_end_2 = edit[2], edit[4] new_end_1, new_end_2 = edit[2], edit[4]
tmp_new_end_1 = edit[2] tmp_new_end_1 = edit[2]
for i in reversed(range(new_start_1, edit[2])): for i in reversed(range(new_start_1, edit[2])):
if not span2.endswith(src_tokens[i] + common_str): if not span2.endswith(src_tokens[i] + common_str):
break break
common_str = src_tokens[i] + common_str common_str = src_tokens[i] + common_str
tmp_new_end_1 = i tmp_new_end_1 = i
if common_str: if common_str:
tmp_str = "" tmp_str = ""
for i in reversed(range(new_start_2, edit[4])): for i in reversed(range(new_start_2, edit[4])):
tmp_str = tgt_tokens[i] + tmp_str tmp_str = tgt_tokens[i] + tmp_str
if tmp_str == common_str: if tmp_str == common_str:
new_end_1, new_end_2 = tmp_new_end_1, i new_end_1, new_end_2 = tmp_new_end_1, i
b = False b = False
break break
elif len(tmp_str) > len(common_str): elif len(tmp_str) > len(common_str):
break break
if b: if b:
second_filter.append(edit) second_filter.append(edit)
else: else:
if new_start_1 == new_end_1: if new_start_1 == new_end_1:
new_edit = ("I", new_start_1, new_end_1, new_start_2, new_end_2) new_edit = ("I", new_start_1, new_end_1, new_start_2, new_end_2)
elif new_start_2 == new_end_2: elif new_start_2 == new_end_2:
new_edit = ("D", new_start_1, new_end_1, new_start_2, new_end_2) new_edit = ("D", new_start_1, new_end_1, new_start_2, new_end_2)
else: else:
new_edit = ("S", new_start_1, new_end_1, new_start_2, new_end_2) new_edit = ("S", new_start_1, new_end_1, new_start_2, new_end_2)
second_filter.append(new_edit) second_filter.append(new_edit)
else: else:
second_filter.append(edit) second_filter.append(edit)
if verbose: if verbose:
print("========== Parallels ==========") print("========== Parallels ==========")
print("".join(src_tokens)) print("".join(src_tokens))
print("".join(tgt_tokens)) print("".join(tgt_tokens))
print("========== Results ==========") print("========== Results ==========")
for edit in second_filter: for edit in second_filter:
op = edit[0] op = edit[0]
s = " ".join(src_tokens[edit[1]: edit[2]]) s = " ".join(src_tokens[edit[1]: edit[2]])
t = " ".join(tgt_tokens[edit[3]: edit[4]]) t = " ".join(tgt_tokens[edit[3]: edit[4]])
print(f"{op}:\t{s}\t-->\t{t}") print(f"{op}:\t{s}\t-->\t{t}")
print("========== Infos ==========") print("========== Infos ==========")
print(str(src)) print(str(src))
print(str(tgt)) print(str(tgt))
return second_filter return second_filter
if __name__ == "__main__": if __name__ == "__main__":
tokenizer = Tokenizer("char") tokenizer = Tokenizer("char")
semantic_dict, semantic_class = read_cilin() semantic_dict, semantic_class = read_cilin()
confusion_dict = read_confusion() confusion_dict = read_confusion()
alignment = Alignment(semantic_dict, confusion_dict) alignment = Alignment(semantic_dict, confusion_dict)
sents = [ sents = [
"所 以 印 度 对 全 世 界 人 没 有 说 服 不 要 吃 牛 肉 。".replace( "所 以 印 度 对 全 世 界 人 没 有 说 服 不 要 吃 牛 肉 。".replace(
" ", ""), " ", ""),
"所 以 印 度 没 有 说 服 全 世 界 人 不 要 吃 牛 肉 。".replace( "所 以 印 度 没 有 说 服 全 世 界 人 不 要 吃 牛 肉 。".replace(
" ", "")] " ", "")]
src, tgt = tokenizer(sents) src, tgt = tokenizer(sents)
align_obj = alignment(src, tgt) align_obj = alignment(src, tgt)
m = Merger() m = Merger()
m(align_obj, src, tgt, verbose=True) m(align_obj, src, tgt, verbose=True)
\ No newline at end of file
from ltp import LTP from ltp import LTP
from typing import List from typing import List
from pypinyin import pinyin, Style, lazy_pinyin from pypinyin import pinyin, Style, lazy_pinyin
import torch import torch
import os import os
import functools import functools
class Tokenizer: class Tokenizer:
""" """
Tokenizer Tokenizer
""" """
def __init__(self, def __init__(self,
granularity: str = "word", granularity: str = "word",
device: str = "cpu", device: str = "cpu",
segmented: bool = False, segmented: bool = False,
bpe: bool = False, bpe: bool = False,
) -> None: ) -> None:
""" """
Constructor Constructor
:param granularity: tokenization granularity, char level ("char") or word level ("word") :param granularity: tokenization granularity, char level ("char") or word level ("word")
""" """
self.ltp = None self.ltp = None
if granularity == "word": if granularity == "word":
self.ltp = LTP(device=torch.device(device) if torch.cuda.is_available() else torch.device("cpu")) self.ltp = LTP(device=torch.device(device) if torch.cuda.is_available() else torch.device("cpu"))
self.ltp.add_words(words=["[缺失成分]"], max_window=6) self.ltp.add_words(words=["[缺失成分]"], max_window=6)
self.segmented = segmented self.segmented = segmented
self.granularity = granularity self.granularity = granularity
if self.granularity == "word": if self.granularity == "word":
self.tokenizer = self.split_word self.tokenizer = self.split_word
elif self.granularity == "char": elif self.granularity == "char":
self.tokenizer = functools.partial(self.split_char, bpe=bpe) self.tokenizer = functools.partial(self.split_char, bpe=bpe)
else: else:
raise NotImplementedError raise NotImplementedError
def __repr__(self) -> str: def __repr__(self) -> str:
return "{:s}\nMode:{:s}\n}".format(str(self.__class__.__name__), self.mode) return "{:s}\nMode:{:s}\n}".format(str(self.__class__.__name__), self.mode)
def __call__(self, def __call__(self,
input_strings: List[str] input_strings: List[str]
) -> List: ) -> List:
""" """
Tokenization function Tokenization function
:param input_strings: list of strings to tokenize :param input_strings: list of strings to tokenize
:return: list of tokenization results, each a list of (token, pos_tag, pinyin) tuples :return: list of tokenization results, each a list of (token, pos_tag, pinyin) tuples
""" """
if not self.segmented: if not self.segmented:
input_strings = ["".join(s.split(" ")) for s in input_strings] input_strings = ["".join(s.split(" ")) for s in input_strings]
results = self.tokenizer(input_strings) results = self.tokenizer(input_strings)
return results return results
def split_char(self, input_strings: List[str], bpe=False) -> List: def split_char(self, input_strings: List[str], bpe=False) -> List:
""" """
Character-level tokenization function Character-level tokenization function
:param input_strings: strings to split into characters :param input_strings: strings to split into characters
:return: character-level tokenization results :return: character-level tokenization results
""" """
if bpe: if bpe:
from . import tokenization from . import tokenization
project_dir = os.path.dirname(os.path.dirname(__file__)) project_dir = os.path.dirname(os.path.dirname(__file__))
tokenizer = tokenization.FullTokenizer(vocab_file=os.path.join(project_dir,"data","chinese_vocab.txt"), do_lower_case=False) tokenizer = tokenization.FullTokenizer(vocab_file=os.path.join(project_dir,"data","chinese_vocab.txt"), do_lower_case=False)
results = [] results = []
for input_string in input_strings: for input_string in input_strings:
if not self.segmented: # if the input has not been pre-segmented, separate every character (no special handling of English punctuation, no BPE); otherwise keep the given segmentation if not self.segmented: # if the input has not been pre-segmented, separate every character (no special handling of English punctuation, no BPE); otherwise keep the given segmentation
segment_string = " ".join([char for char in input_string] if not bpe else tokenizer.tokenize(input_string)) segment_string = " ".join([char for char in input_string] if not bpe else tokenizer.tokenize(input_string))
else: else:
segment_string = input_string segment_string = input_string
# print(segment_string) # print(segment_string)
segment_string = segment_string.replace("[ 缺 失 成 分 ]", "[缺失成分]").split(" ") # treat the missing-component tag as a single token segment_string = segment_string.replace("[ 缺 失 成 分 ]", "[缺失成分]").split(" ") # treat the missing-component tag as a single token
results.append([(char, "unk", pinyin(char, style=Style.NORMAL, heteronym=True)[0]) for char in segment_string]) results.append([(char, "unk", pinyin(char, style=Style.NORMAL, heteronym=True)[0]) for char in segment_string])
return results return results
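# split_char returns, per input string, a list of (token, pos, pinyin_candidates) tuples;
# the POS slot is always "unk" at char level, e.g. (illustrative)
#   [("我", "unk", ["wo"]), ("们", "unk", ["men"]), ...]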
def split_word(self, input_strings: List[str]) -> List: def split_word(self, input_strings: List[str]) -> List:
""" """
Word-level tokenization function Word-level tokenization function
:param input_strings: strings to tokenize into words :param input_strings: strings to tokenize into words
:return: word-level tokenization results :return: word-level tokenization results
""" """
if self.segmented: if self.segmented:
seg, hidden = self.ltp.seg([input_string.split(" ") for input_string in input_strings], is_preseged=True) seg, hidden = self.ltp.seg([input_string.split(" ") for input_string in input_strings], is_preseged=True)
else: else:
seg, hidden = self.ltp.seg(input_strings) seg, hidden = self.ltp.seg(input_strings)
pos = self.ltp.pos(hidden) pos = self.ltp.pos(hidden)
result = [] result = []
for s, p in zip(seg, pos): for s, p in zip(seg, pos):
pinyin = [lazy_pinyin(word) for word in s] pinyin = [lazy_pinyin(word) for word in s]
result.append(list(zip(s, p, pinyin))) result.append(list(zip(s, p, pinyin)))
return result return result
if __name__ == "__main__": if __name__ == "__main__":
tokenizer = Tokenizer("word") tokenizer = Tokenizer("word")
print(tokenizer(["LAC是个优秀的分词工具", "百度是一家高科技公司"])) print(tokenizer(["LAC是个优秀的分词工具", "百度是一家高科技公司"]))
import os import os
from modules.annotator import Annotator from modules.annotator import Annotator
from modules.tokenizer import Tokenizer from modules.tokenizer import Tokenizer
import argparse import argparse
from collections import Counter from collections import Counter
from tqdm import tqdm from tqdm import tqdm
import torch import torch
from collections import defaultdict from collections import defaultdict
from multiprocessing import Pool from multiprocessing import Pool
from opencc import OpenCC from opencc import OpenCC
import timeout_decorator import timeout_decorator
os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["TOKENIZERS_PARALLELISM"] = "false"
annotator, sentence_to_tokenized = None, None annotator, sentence_to_tokenized = None, None
cc = OpenCC("t2s") cc = OpenCC("t2s")
@timeout_decorator.timeout(10) @timeout_decorator.timeout(10)
def annotate_with_time_out(line): def annotate_with_time_out(line):
""" """
:param line: :param line:
:return: :return:
""" """
sent_list = line.split("\t")[1:] sent_list = line.split("\t")[1:]
source = sent_list[0] source = sent_list[0]
if args.segmented: if args.segmented:
source = source.strip() source = source.strip()
else: else:
source = "".join(source.strip().split()) source = "".join(source.strip().split())
output_str = "" output_str = ""
for idx, target in enumerate(sent_list[1:]): for idx, target in enumerate(sent_list[1:]):
try: try:
if args.segmented: if args.segmented:
target = target.strip() target = target.strip()
else: else:
target = "".join(target.strip().split()) target = "".join(target.strip().split())
if not args.no_simplified: if not args.no_simplified:
target = cc.convert(target) target = cc.convert(target)
source_tokenized, target_tokenized = sentence_to_tokenized[source], sentence_to_tokenized[target] source_tokenized, target_tokenized = sentence_to_tokenized[source], sentence_to_tokenized[target]
out, cors = annotator(source_tokenized, target_tokenized, idx) out, cors = annotator(source_tokenized, target_tokenized, idx)
if idx == 0: if idx == 0:
output_str += "".join(out[:-1]) output_str += "".join(out[:-1])
else: else:
output_str += "".join(out[1:-1]) output_str += "".join(out[1:-1])
except Exception: except Exception:
raise raise
return output_str return output_str
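# Both annotate_with_time_out above and annotate below expect one tab-separated line of the
# form "id<TAB>source<TAB>target1<TAB>target2..." and return the concatenated M2 blocks for
# all targets; the only difference is the 10-second timeout wrapper on the former.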
def annotate(line): def annotate(line):
""" """
:param line: :param line:
:return: :return:
""" """
sent_list = line.split("\t")[1:] sent_list = line.split("\t")[1:]
source = sent_list[0] source = sent_list[0]
if args.segmented: if args.segmented:
source = source.strip() source = source.strip()
else: else:
source = "".join(source.strip().split()) source = "".join(source.strip().split())
output_str = "" output_str = ""
for idx, target in enumerate(sent_list[1:]): for idx, target in enumerate(sent_list[1:]):
try: try:
if args.segmented: if args.segmented:
target = target.strip() target = target.strip()
else: else:
target = "".join(target.strip().split()) target = "".join(target.strip().split())
if not args.no_simplified: if not args.no_simplified:
target = cc.convert(target) target = cc.convert(target)
source_tokenized, target_tokenized = sentence_to_tokenized[source], sentence_to_tokenized[target] source_tokenized, target_tokenized = sentence_to_tokenized[source], sentence_to_tokenized[target]
out, cors = annotator(source_tokenized, target_tokenized, idx) out, cors = annotator(source_tokenized, target_tokenized, idx)
if idx == 0: if idx == 0:
output_str += "".join(out[:-1]) output_str += "".join(out[:-1])
else: else:
output_str += "".join(out[1:-1]) output_str += "".join(out[1:-1])
except Exception: except Exception:
raise raise
return output_str return output_str
def firsttime_process(args): def firsttime_process(args):
tokenizer = Tokenizer(args.granularity, args.device, args.segmented, args.bpe) tokenizer = Tokenizer(args.granularity, args.device, args.segmented, args.bpe)
global annotator, sentence_to_tokenized global annotator, sentence_to_tokenized
annotator = Annotator.create_default(args.granularity, args.multi_cheapest_strategy) annotator = Annotator.create_default(args.granularity, args.multi_cheapest_strategy)
lines = open(args.file, "r", encoding="utf-8").read().strip().split("\n") # format: id src tgt1 tgt2... lines = open(args.file, "r", encoding="utf-8").read().strip().split("\n") # format: id src tgt1 tgt2...
# error_types = [] # error_types = []
with open(args.output, "w", encoding="utf-8") as f: with open(args.output, "w", encoding="utf-8") as f:
count = 0 count = 0
sentence_set = set() sentence_set = set()
sentence_to_tokenized = {} sentence_to_tokenized = {}
for line in lines: for line in lines:
sent_list = line.split("\t")[1:] sent_list = line.split("\t")[1:]
for idx, sent in enumerate(sent_list): for idx, sent in enumerate(sent_list):
if args.segmented: if args.segmented:
# print(sent) # print(sent)
sent = sent.strip() sent = sent.strip()
else: else:
sent = "".join(sent.split()).strip() sent = "".join(sent.split()).strip()
if idx >= 1: if idx >= 1:
if not args.no_simplified: if not args.no_simplified:
sentence_set.add(cc.convert(sent)) sentence_set.add(cc.convert(sent))
else: else:
sentence_set.add(sent) sentence_set.add(sent)
else: else:
sentence_set.add(sent) sentence_set.add(sent)
batch = [] batch = []
for sent in tqdm(sentence_set): for sent in tqdm(sentence_set):
count += 1 count += 1
if sent: if sent:
batch.append(sent) batch.append(sent)
if count % args.batch_size == 0: if count % args.batch_size == 0:
results = tokenizer(batch) results = tokenizer(batch)
for s, r in zip(batch, results): for s, r in zip(batch, results):
sentence_to_tokenized[s] = r # Get tokenization map. sentence_to_tokenized[s] = r # Get tokenization map.
batch = [] batch = []
if batch: if batch:
results = tokenizer(batch) results = tokenizer(batch)
for s, r in zip(batch, results): for s, r in zip(batch, results):
sentence_to_tokenized[s] = r # Get tokenization map. sentence_to_tokenized[s] = r # Get tokenization map.
timeout_indices = [] timeout_indices = []
# single-process mode # single-process mode
for idx, line in enumerate(tqdm(lines)): for idx, line in enumerate(tqdm(lines)):
try: try:
ret = annotate_with_time_out(line) ret = annotate_with_time_out(line)
except Exception: except Exception:
timeout_indices.append(idx) timeout_indices.append(idx)
return timeout_indices return timeout_indices
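# firsttime_process does a dry run: it tokenizes every sentence once and collects the indices
# of input lines whose annotation exceeded the 10-second timeout, so that main() can replace
# those targets with a placeholder before the real annotation pass.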
def main(args): def main(args):
timeout_indices = firsttime_process(args) timeout_indices = firsttime_process(args)
tokenizer = Tokenizer(args.granularity, args.device, args.segmented, args.bpe) tokenizer = Tokenizer(args.granularity, args.device, args.segmented, args.bpe)
global annotator, sentence_to_tokenized global annotator, sentence_to_tokenized
annotator = Annotator.create_default(args.granularity, args.multi_cheapest_strategy) annotator = Annotator.create_default(args.granularity, args.multi_cheapest_strategy)
lines = open(args.file, "r", encoding="utf-8").read().strip().split("\n") lines = open(args.file, "r", encoding="utf-8").read().strip().split("\n")
new_lines = [] # format: id src tgt1 tgt2... new_lines = [] # format: id src tgt1 tgt2...
with open(args.output, "w", encoding="utf-8") as f: with open(args.output, "w", encoding="utf-8") as f:
count = 0 count = 0
sentence_set = set() sentence_set = set()
sentence_to_tokenized = {} sentence_to_tokenized = {}
for line_idx, line in enumerate(lines): for line_idx, line in enumerate(lines):
if line_idx in timeout_indices: if line_idx in timeout_indices:
# print(f"line before split: {line}") # print(f"line before split: {line}")
line_split = line.split("\t") line_split = line.split("\t")
line_number, sent_list = line_split[0], line_split[1:] line_number, sent_list = line_split[0], line_split[1:]
assert len(sent_list) == 2 assert len(sent_list) == 2
sent_list[-1] = " 无" sent_list[-1] = " 无"
line = line_number + "\t" + "\t".join(sent_list) line = line_number + "\t" + "\t".join(sent_list)
# print(f"line time out: {line}") # print(f"line time out: {line}")
new_lines.append(line) new_lines.append(line)
else: else:
new_lines.append(line) new_lines.append(line)
sent_list = line.split("\t")[1:] sent_list = line.split("\t")[1:]
for idx, sent in enumerate(sent_list): for idx, sent in enumerate(sent_list):
if args.segmented: if args.segmented:
# print(sent) # print(sent)
sent = sent.strip() sent = sent.strip()
else: else:
sent = "".join(sent.split()).strip() sent = "".join(sent.split()).strip()
if idx >= 1: if idx >= 1:
if not args.no_simplified: if not args.no_simplified:
sentence_set.add(cc.convert(sent)) sentence_set.add(cc.convert(sent))
else: else:
sentence_set.add(sent) sentence_set.add(sent)
else: else:
sentence_set.add(sent) sentence_set.add(sent)
batch = [] batch = []
for sent in tqdm(sentence_set): for sent in tqdm(sentence_set):
count += 1 count += 1
if sent: if sent:
batch.append(sent) batch.append(sent)
if count % args.batch_size == 0: if count % args.batch_size == 0:
results = tokenizer(batch) results = tokenizer(batch)
for s, r in zip(batch, results): for s, r in zip(batch, results):
sentence_to_tokenized[s] = r # Get tokenization map. sentence_to_tokenized[s] = r # Get tokenization map.
batch = [] batch = []
if batch: if batch:
results = tokenizer(batch) results = tokenizer(batch)
for s, r in zip(batch, results): for s, r in zip(batch, results):
sentence_to_tokenized[s] = r # Get tokenization map. sentence_to_tokenized[s] = r # Get tokenization map.
# single-process mode # single-process mode
lines = new_lines lines = new_lines
for idx, line in enumerate(tqdm(lines)): for idx, line in enumerate(tqdm(lines)):
ret = annotate(line) ret = annotate(line)
f.write(ret) f.write(ret)
f.write("\n") f.write("\n")
# multi-process mode: tested only on Linux; recommended for use on a Linux server # multi-process mode: tested only on Linux; recommended for use on a Linux server
# with Pool(args.worker_num) as pool: # with Pool(args.worker_num) as pool:
# for ret in pool.imap(annotate, tqdm(lines), chunksize=8): # for ret in pool.imap(annotate, tqdm(lines), chunksize=8):
# if ret: # if ret:
# f.write(ret) # f.write(ret)
# f.write("\n") # f.write("\n")
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Choose input file to annotate") parser = argparse.ArgumentParser(description="Choose input file to annotate")
parser.add_argument("-f", "--file", type=str, required=True, help="Input parallel file") parser.add_argument("-f", "--file", type=str, required=True, help="Input parallel file")
parser.add_argument("-o", "--output", type=str, help="Output file", required=True) parser.add_argument("-o", "--output", type=str, help="Output file", required=True)
parser.add_argument("-b", "--batch_size", type=int, help="The size of batch", default=128) parser.add_argument("-b", "--batch_size", type=int, help="The size of batch", default=128)
parser.add_argument("-d", "--device", type=int, help="The ID of GPU", default=0) parser.add_argument("-d", "--device", type=int, help="The ID of GPU", default=0)
parser.add_argument("-w", "--worker_num", type=int, help="The number of workers", default=16) parser.add_argument("-w", "--worker_num", type=int, help="The number of workers", default=16)
parser.add_argument("-g", "--granularity", type=str, help="Choose char-level or word-level evaluation", default="char") parser.add_argument("-g", "--granularity", type=str, help="Choose char-level or word-level evaluation", default="char")
parser.add_argument("-m", "--merge", help="Whether merge continuous replacement/deletion/insertion", action="store_true") parser.add_argument("-m", "--merge", help="Whether merge continuous replacement/deletion/insertion", action="store_true")
parser.add_argument("-s", "--multi_cheapest_strategy", type=str, choices=["first", "all"], default="all") parser.add_argument("-s", "--multi_cheapest_strategy", type=str, choices=["first", "all"], default="all")
parser.add_argument("--segmented", help="Whether tokens have been segmented", action="store_true") # 支持提前token化,用空格隔开 parser.add_argument("--segmented", help="Whether tokens have been segmented", action="store_true") # 支持提前token化,用空格隔开
parser.add_argument("--no_simplified", help="Whether simplifying chinese", action="store_true") # 将所有corrections转换为简体中文 parser.add_argument("--no_simplified", help="Whether simplifying chinese", action="store_true") # 将所有corrections转换为简体中文
parser.add_argument("--bpe", help="Whether to use bpe", action="store_true") # 支持 bpe 切分英文单词 parser.add_argument("--bpe", help="Whether to use bpe", action="store_true") # 支持 bpe 切分英文单词
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)
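# Example invocation of this script (script and file names are placeholders), producing
# char-level M2 edits from a tab-separated parallel file:
#   python parallel_to_m2.py -f demo.para -o demo.m2 -g char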
import datetime
import os import os
import os.path as osp import os.path as osp
import random import random
import re
import subprocess import subprocess
import sys
import time import time
from functools import partial from functools import partial
from typing import Any, Dict, List, Tuple from typing import Any, Dict, List, Optional, Tuple
import mmengine import mmengine
from mmengine.config import ConfigDict from mmengine.config import ConfigDict
...@@ -43,6 +46,11 @@ class DLCRunner(BaseRunner): ...@@ -43,6 +46,11 @@ class DLCRunner(BaseRunner):
self.max_num_workers = max_num_workers self.max_num_workers = max_num_workers
self.retry = retry self.retry = retry
logger = get_logger()
logger.warning(
'To ensure the integrity of the log results, the log displayed '
f'by {self.__class__.__name__} has a 10-second delay.')
def launch(self, tasks: List[Dict[str, Any]]) -> List[Tuple[str, int]]: def launch(self, tasks: List[Dict[str, Any]]) -> List[Tuple[str, int]]:
"""Launch multiple tasks. """Launch multiple tasks.
...@@ -63,18 +71,23 @@ class DLCRunner(BaseRunner): ...@@ -63,18 +71,23 @@ class DLCRunner(BaseRunner):
status = [self._launch(task, random_sleep=False) for task in tasks] status = [self._launch(task, random_sleep=False) for task in tasks]
return status return status
def _launch(self, cfg: ConfigDict, random_sleep: bool = True): def _launch(self, cfg: ConfigDict, random_sleep: Optional[bool] = None):
"""Launch a single task. """Launch a single task.
Args: Args:
cfg (ConfigDict): Task config. cfg (ConfigDict): Task config.
random_sleep (bool): Whether to sleep for a random time before random_sleep (bool): Whether to sleep for a random time before
running the command. This avoids cluster error when launching running the command. When Aliyun has many tasks to schedule,
multiple tasks at the same time. Default: True. its stability decreases. Therefore, when we need to submit a
large number of tasks at once, we adopt the "random_sleep"
strategy. Tasks that would have been submitted all at once are
now evenly spread out over a 10-second period. Default: None.
Returns: Returns:
tuple[str, int]: Task name and exit code. tuple[str, int]: Task name and exit code.
""" """
if random_sleep is None:
random_sleep = (self.max_num_workers > 32)
task = TASKS.build(dict(cfg=cfg, type=self.task_cfg['type'])) task = TASKS.build(dict(cfg=cfg, type=self.task_cfg['type']))
num_gpus = task.num_gpus num_gpus = task.num_gpus
...@@ -116,7 +129,7 @@ class DLCRunner(BaseRunner): ...@@ -116,7 +129,7 @@ class DLCRunner(BaseRunner):
# Run command with retry # Run command with retry
if self.debug: if self.debug:
stdout = None stdout = sys.stdout
else: else:
out_path = task.get_log_path(file_extension='out') out_path = task.get_log_path(file_extension='out')
mmengine.mkdir_or_exist(osp.split(out_path)[0]) mmengine.mkdir_or_exist(osp.split(out_path)[0])
...@@ -124,30 +137,92 @@ class DLCRunner(BaseRunner): ...@@ -124,30 +137,92 @@ class DLCRunner(BaseRunner):
if random_sleep: if random_sleep:
time.sleep(random.randint(0, 10)) time.sleep(random.randint(0, 10))
result = subprocess.run(cmd,
shell=True,
text=True,
stdout=stdout,
stderr=stdout)
def _run_within_retry():
try:
process = subprocess.Popen(cmd,
shell=True,
text=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
job_id = None
job_allocated = False
job_finished = False
last_end_time = datetime.datetime.now().strftime(
'%Y-%m-%dT%H:%M:%SZ')
while True:
if not job_allocated:
line = process.stdout.readline()
if not line:
break
match = re.search(r'(dlc[0-9a-z]+)', line)
if match and job_id is None:
job_id = match.group(1)
stdout.write(line)
match = re.search(r'Job .* is \[Running\]', line)
if match:
job_allocated = True
else:
try:
process.wait(10)
except subprocess.TimeoutExpired:
pass
else:
job_finished = True
if job_finished:
this_end_time = datetime.datetime.now(
).strftime('%Y-%m-%dT%H:%M:%SZ')
else:
this_end_time = (
datetime.datetime.now() -
datetime.timedelta(seconds=10)
).strftime('%Y-%m-%dT%H:%M:%SZ')
logs_cmd = (
'dlc logs'
f' {job_id} {job_id}-worker-0'
f' --start_time {last_end_time}'
f' --end_time {this_end_time}'
f" -c {self.aliyun_cfg['dlc_config_path']}")
log_process = subprocess.Popen(
logs_cmd,
shell=True,
text=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
log_output, log_err = log_process.communicate()
log_output = '\n'.join(log_output.split('\n')[2:])
stdout.write(log_output)
last_end_time = this_end_time
stdout.flush()
if job_finished:
break
process.wait()
return process.returncode
finally:
if job_id is not None:
cancel_cmd = (
'dlc stop job'
f' {job_id}'
f" -c {self.aliyun_cfg['dlc_config_path']}"
' -f')
subprocess.run(cancel_cmd,
shell=True,
text=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
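# The polling loop above streams `dlc logs` in roughly 10-second windows: while the job is
# running it waits up to 10 s, then fetches logs from last_end_time up to now minus 10 s
# (hence the startup warning about a 10-second display delay); once the job finishes it
# fetches the remaining window and exits. The finally block always issues `dlc stop job`
# so an interrupted run does not leave the remote job alive.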
return_code = _run_within_retry()
retry = self.retry retry = self.retry
output_paths = task.get_output_paths() output_paths = task.get_output_paths()
while self._job_failed(result.returncode, while self._job_failed(return_code, output_paths) and retry > 0:
output_paths) and retry > 0:
retry -= 1 retry -= 1
if random_sleep:
time.sleep(random.randint(0, 10))
# Re-generate command to refresh ports.
cmd = get_cmd() cmd = get_cmd()
result = subprocess.run(cmd, return_code = _run_within_retry()
shell=True,
text=True,
stdout=stdout,
stderr=stdout)
finally: finally:
# Clean up # Clean up
os.remove(param_file) os.remove(param_file)
return task_name, result.returncode
return task_name, return_code
def _job_failed(self, return_code: int, output_paths: List[str]) -> bool: def _job_failed(self, return_code: int, output_paths: List[str]) -> bool:
return return_code != 0 or not all( return return_code != 0 or not all(
......