"configs/datasets/vscode:/vscode.git/clone" did not exist on "f480b72703b3a82e98b29480750d10ccdb8e7f49"
Unverified Commit 4dd9a3fc authored by Leymore's avatar Leymore Committed by GitHub
Browse files

[Sync] sync with internal codes 20231019 (#488)

parent 2737249f
import argparse
from collections import Counter
def main():
# Parse command line args
args = parse_args()
# Open hypothesis and reference m2 files and split into chunks
hyp_m2 = open(args.hyp).read().strip().split("\n\n")[args.start:args.end] if args.start is not None and args.end is not None else open(args.hyp).read().strip().split("\n\n")
ref_m2 = open(args.ref).read().strip().split("\n\n")[args.start:args.end] if args.start is not None and args.end is not None else open(args.ref).read().strip().split("\n\n")
# Make sure they have the same number of sentences
assert len(hyp_m2) == len(ref_m2), print(len(hyp_m2), len(ref_m2))
# Store global corpus level best counts here
best_dict = Counter({"tp":0, "fp":0, "fn":0})
best_cats = {}
# Process each sentence
sents = zip(hyp_m2, ref_m2)
for sent_id, sent in enumerate(sents):
# Simplify the edits into lists of lists
# if "A1" in sent[0] or "A1" in sent[1] or sent_id in sent_id_cons:
# sent_id_cons.append(sent_id)
src = sent[0].split("\n")[0]
hyp_edits = simplify_edits(sent[0], args.max_answer_num)
ref_edits = simplify_edits(sent[1], args.max_answer_num)
# Process the edits for detection/correction based on args
hyp_dict = process_edits(hyp_edits, args)
ref_dict = process_edits(ref_edits, args)
if args.reference_num is None or len(ref_dict.keys()) == args.reference_num:
# Evaluate edits and get best TP, FP, FN hyp+ref combo.
count_dict, cat_dict = evaluate_edits(src,
hyp_dict, ref_dict, best_dict, sent_id, args)
# Merge these dicts with best_dict and best_cats
best_dict += Counter(count_dict)
best_cats = merge_dict(best_cats, cat_dict)
# Print results
print_results(best_dict, best_cats, args)
# Parse command line args
def parse_args():
parser = argparse.ArgumentParser(
description="Calculate F-scores for error detection and/or correction.\n"
"Flags let you evaluate at different levels of granularity.",
formatter_class=argparse.RawTextHelpFormatter,
usage="%(prog)s [options] -hyp HYP -ref REF")
parser.add_argument(
"-hyp",
help="A hypothesis M2 file.",
required=True)
parser.add_argument(
"-ref",
help="A reference M2 file.",
required=True)
parser.add_argument(
"--start",
type=int,
default=None
)
parser.add_argument(
"--end",
type=int,
default=None
)
parser.add_argument(
"--max_answer_num",
type=int,
default=None
)
parser.add_argument(
"--reference_num",
type=int,
default=None
)
parser.add_argument(
"-b",
"--beta",
help="Value of beta in F-score. (default: 0.5)",
default=0.5,
type=float)
parser.add_argument(
"-v",
"--verbose",
help="Print verbose output.",
action="store_true")
eval_type = parser.add_mutually_exclusive_group()
eval_type.add_argument(
"-dt",
help="Evaluate Detection in terms of Tokens.",
action="store_true")
eval_type.add_argument(
"-ds",
help="Evaluate Detection in terms of Spans.",
action="store_true")
eval_type.add_argument(
"-cs",
help="Evaluate Correction in terms of Spans. (default)",
action="store_true")
eval_type.add_argument(
"-cse",
help="Evaluate Correction in terms of Spans and Error types.",
action="store_true")
parser.add_argument(
"-single",
help="Only evaluate single token edits; i.e. 0:1, 1:0 or 1:1",
action="store_true")
parser.add_argument(
"-multi",
help="Only evaluate multi token edits; i.e. 2+:n or n:2+",
action="store_true")
parser.add_argument(
"-multi_hyp_avg",
help="When get multiple hypotheses for a sentence, calculate their average F-scores for this sentence.",
action="store_true") # For IAA calculation
parser.add_argument(
"-multi_hyp_max",
help="When get multiple hypotheses for a sentence, calculate their F-scores and select the max one for this sentence.",
action="store_true") # For multiple hypotheses system evaluation
parser.add_argument(
"-filt",
help="Do not evaluate the specified error types.",
nargs="+",
default=[])
parser.add_argument(
"-cat",
help="Show error category scores.\n"
"1: Only show operation tier scores; e.g. R.\n"
"2: Only show main tier scores; e.g. NOUN.\n"
"3: Show all category scores; e.g. R:NOUN.",
choices=[1, 2, 3],
type=int)
args = parser.parse_args()
return args
# Input: An m2 format sentence with edits.
# Output: A list of lists. Each edit: [start, end, cat, cor, coder]
def simplify_edits(sent, max_answer_num):
out_edits = []
# Get the edit lines from an m2 block.
edits = sent.split("\n")
# Loop through the edits
for edit in edits:
# Preprocessing
if edit.startswith("A "):
edit = edit[2:].split("|||") # Ignore "A " then split.
span = edit[0].split()
start = int(span[0])
end = int(span[1])
cat = edit[1]
cor = edit[2].replace(" ", "")
coder = int(edit[-1])
out_edit = [start, end, cat, cor, coder]
out_edits.append(out_edit)
# return [edit for edit in out_edits if edit[-1] in [0,1]]
if max_answer_num is None:
return out_edits
elif max_answer_num == 1:
return [edit for edit in out_edits if edit[-1] == 0]
elif max_answer_num == 2:
return [edit for edit in out_edits if edit[-1] in [0, 1]]
elif max_answer_num == 3:
return [edit for edit in out_edits if edit[-1] in [0, 1, 2]]
# Input 1: A list of edits. Each edit: [start, end, cat, cor, coder]
# Input 2: Command line args
# Output: A dict; key is coder, value is edit dict.
def process_edits(edits, args):
coder_dict = {}
# Add an explicit noop edit if there are no edits.
if not edits: edits = [[-1, -1, "noop", "-NONE-", 0]]
# Loop through the edits
for edit in edits:
# Name the edit elements for clarity
start = edit[0]
end = edit[1]
cat = edit[2]
cor = edit[3]
coder = edit[4]
# Add the coder to the coder_dict if necessary
if coder not in coder_dict: coder_dict[coder] = {}
# Optionally apply filters based on args
# 1. UNK type edits are only useful for detection, not correction.
if not args.dt and not args.ds and cat == "UNK": continue
# 2. Only evaluate single token edits; i.e. 0:1, 1:0 or 1:1
if args.single and (end-start >= 2 or len(cor.split()) >= 2): continue
# 3. Only evaluate multi token edits; i.e. 2+:n or n:2+
if args.multi and end-start < 2 and len(cor.split()) < 2: continue
# 4. If there is a filter, ignore the specified error types
if args.filt and cat in args.filt: continue
# Token Based Detection
if args.dt:
# Preserve noop edits.
if start == -1:
if (start, start) in coder_dict[coder].keys():
coder_dict[coder][(start, start)].append(cat)
else:
coder_dict[coder][(start, start)] = [cat]
# Insertions defined as affecting the token on the right
elif start == end and start >= 0:
if (start, start+1) in coder_dict[coder].keys():
coder_dict[coder][(start, start+1)].append(cat)
else:
coder_dict[coder][(start, start+1)] = [cat]
# Edit spans are split for each token in the range.
else:
for tok_id in range(start, end):
if (tok_id, tok_id+1) in coder_dict[coder].keys():
coder_dict[coder][(tok_id, tok_id+1)].append(cat)
else:
coder_dict[coder][(tok_id, tok_id+1)] = [cat]
# Span Based Detection
elif args.ds:
if (start, end) in coder_dict[coder].keys():
coder_dict[coder][(start, end)].append(cat)
else:
coder_dict[coder][(start, end)] = [cat]
# Span Based Correction
else:
# With error type classification
if args.cse:
if (start, end, cat, cor) in coder_dict[coder].keys():
coder_dict[coder][(start, end, cat, cor)].append(cat)
else:
coder_dict[coder][(start, end, cat, cor)] = [cat]
# Without error type classification
else:
if (start, end, cor) in coder_dict[coder].keys():
coder_dict[coder][(start, end, cor)].append(cat)
else:
coder_dict[coder][(start, end, cor)] = [cat]
return coder_dict
# Input 1: A hyp dict; key is coder_id, value is dict of processed hyp edits.
# Input 2: A ref dict; key is coder_id, value is dict of processed ref edits.
# Input 3: A dictionary of the best corpus level TP, FP and FN counts so far.
# Input 4: Sentence ID (for verbose output only)
# Input 5: Command line args
# Output 1: A dict of the best corpus level TP, FP and FN for the input sentence.
# Output 2: The corresponding error type dict for the above dict.
def evaluate_edits(src, hyp_dict, ref_dict, best, sent_id, args):
# Store the best sentence level scores and hyp+ref combination IDs
# best_f is initialised as -1 cause 0 is a valid result.
best_tp, best_fp, best_fn, best_f, best_hyp, best_ref = 0, 0, 0, -1, 0, 0
best_cat = {}
# skip not annotatable sentence
if len(ref_dict.keys()) == 1:
ref_id = list(ref_dict.keys())[0]
if len(ref_dict[ref_id].keys()) == 1:
cat = list(ref_dict[ref_id].values())[0][0]
if cat == "NA":
best_dict = {"tp":best_tp, "fp":best_fp, "fn":best_fn}
return best_dict, best_cat
# Compare each hyp and ref combination
for hyp_id in hyp_dict.keys():
for ref_id in ref_dict.keys():
# Get the local counts for the current combination.
tp, fp, fn, cat_dict = compareEdits(hyp_dict[hyp_id], ref_dict[ref_id])
# Compute the local sentence scores (for verbose output only)
loc_p, loc_r, loc_f = computeFScore(tp, fp, fn, args.beta)
# Compute the global sentence scores
p, r, f = computeFScore(
tp+best["tp"], fp+best["fp"], fn+best["fn"], args.beta)
# Save the scores if they are better in terms of:
# 1. Higher F-score
# 2. Same F-score, higher TP
# 3. Same F-score and TP, lower FP
# 4. Same F-score, TP and FP, lower FN
if (f > best_f) or \
(f == best_f and tp > best_tp) or \
(f == best_f and tp == best_tp and fp < best_fp) or \
(f == best_f and tp == best_tp and fp == best_fp and fn < best_fn):
best_tp, best_fp, best_fn = tp, fp, fn
best_f, best_hyp, best_ref = f, hyp_id, ref_id
best_cat = cat_dict
# Verbose output
if args.verbose:
# Prepare verbose output edits.
hyp_verb = list(sorted(hyp_dict[hyp_id].keys()))
ref_verb = list(sorted(ref_dict[ref_id].keys()))
# Ignore noop edits
if not hyp_verb or hyp_verb[0][0] == -1: hyp_verb = []
if not ref_verb or ref_verb[0][0] == -1: ref_verb = []
# Print verbose info
print('{:-^40}'.format(""))
print("SENTENCE "+str(sent_id)+src[1:])
print('{:-^40}'.format(""))
print("SENTENCE "+str(sent_id)+" - HYP "+str(hyp_id)+" - REF "+str(ref_id))
print("HYPOTHESIS EDITS :", hyp_verb)
print("REFERENCE EDITS :", ref_verb)
print("Local TP/FP/FN :", str(tp), str(fp), str(fn))
print("Local P/R/F"+str(args.beta)+" :", str(loc_p), str(loc_r), str(loc_f))
print("Global TP/FP/FN :", str(tp+best["tp"]), str(fp+best["fp"]), str(fn+best["fn"]))
print("Global P/R/F"+str(args.beta)+" :", str(p), str(r), str(f))
# Verbose output: display the best hyp+ref combination
if args.verbose:
print('{:-^40}'.format(""))
print("^^ HYP "+str(best_hyp)+", REF "+str(best_ref)+" chosen for sentence "+str(sent_id))
# Save the best TP, FP and FNs as a dict, and return this and the best_cat dict
best_dict = {"tp":best_tp, "fp":best_fp, "fn":best_fn}
return best_dict, best_cat
# Input 1: A dictionary of hypothesis edits for a single system.
# Input 2: A dictionary of reference edits for a single annotator.
# Output 1-3: The TP, FP and FN for the hyp vs the given ref annotator.
# Output 4: A dictionary of the error type counts.
def compareEdits(hyp_edits, ref_edits):
tp = 0 # True Positives
fp = 0 # False Positives
fn = 0 # False Negatives
cat_dict = {} # {cat: [tp, fp, fn], ...}
for h_edit, h_cats in hyp_edits.items():
# noop hyp edits cannot be TP or FP
if h_cats[0] == "noop": continue
# TRUE POSITIVES
if h_edit in ref_edits.keys():
# On occasion, multiple tokens at same span.
for h_cat in ref_edits[h_edit]: # Use ref dict for TP
tp += 1
# Each dict value [TP, FP, FN]
if h_cat in cat_dict.keys():
cat_dict[h_cat][0] += 1
else:
cat_dict[h_cat] = [1, 0, 0]
# FALSE POSITIVES
else:
# On occasion, multiple tokens at same span.
for h_cat in h_cats:
fp += 1
# Each dict value [TP, FP, FN]
if h_cat in cat_dict.keys():
cat_dict[h_cat][1] += 1
else:
cat_dict[h_cat] = [0, 1, 0]
for r_edit, r_cats in ref_edits.items():
# noop ref edits cannot be FN
if r_cats[0] == "noop": continue
# FALSE NEGATIVES
if r_edit not in hyp_edits.keys():
# On occasion, multiple tokens at same span.
for r_cat in r_cats:
fn += 1
# Each dict value [TP, FP, FN]
if r_cat in cat_dict.keys():
cat_dict[r_cat][2] += 1
else:
cat_dict[r_cat] = [0, 0, 1]
return tp, fp, fn, cat_dict
# Input 1-3: True positives, false positives, false negatives
# Input 4: Value of beta in F-score.
# Output 1-3: Precision, Recall and F-score rounded to 4dp.
def computeFScore(tp, fp, fn, beta):
p = float(tp)/(tp+fp) if fp else 1.0
r = float(tp)/(tp+fn) if fn else 1.0
f = float((1+(beta**2))*p*r)/(((beta**2)*p)+r) if p+r else 0.0
return round(p, 4), round(r, 4), round(f, 4)
# Input 1-2: Two error category dicts. Key is cat, value is list of TP, FP, FN.
# Output: The dictionaries combined with cumulative TP, FP, FN.
def merge_dict(dict1, dict2):
for cat, stats in dict2.items():
if cat in dict1.keys():
dict1[cat] = [x+y for x, y in zip(dict1[cat], stats)]
else:
dict1[cat] = stats
return dict1
# Input 1: A dict; key is error cat, value is counts for [tp, fp, fn]
# Input 2: Integer value denoting level of error category granularity.
# 1: Operation tier; e.g. M, R, U. 2: Main tier; e.g. NOUN, VERB 3: Everything.
# Output: A dictionary of category TP, FP and FN based on Input 2.
def processCategories(cat_dict, setting):
# Otherwise, do some processing.
proc_cat_dict = {}
for cat, cnt in cat_dict.items():
if cat == "UNK":
proc_cat_dict[cat] = cnt
continue
# M, U, R or UNK combined only.
if setting == 1:
if cat[0] in proc_cat_dict.keys():
proc_cat_dict[cat[0]] = [x+y for x, y in zip(proc_cat_dict[cat[0]], cnt)]
else:
proc_cat_dict[cat[0]] = cnt
# Everything without M, U or R.
elif setting == 2:
if cat[2:] in proc_cat_dict.keys():
proc_cat_dict[cat[2:]] = [x+y for x, y in zip(proc_cat_dict[cat[2:]], cnt)]
else:
proc_cat_dict[cat[2:]] = cnt
# All error category combinations
else:
return cat_dict
return proc_cat_dict
# Input 1: A dict of global best TP, FP and FNs
# Input 2: A dict of error types and counts for those TP, FP and FNs
# Input 3: Command line args
def print_results(best, best_cats, args):
# Prepare output title.
if args.dt: title = " Token-Based Detection "
elif args.ds: title = " Span-Based Detection "
elif args.cse: title = " Span-Based Correction + Classification "
else: title = " Span-Based Correction "
# Category Scores
if args.cat:
best_cats = processCategories(best_cats, args.cat)
print("")
print('{:=^66}'.format(title))
print("Category".ljust(14), "TP".ljust(8), "FP".ljust(8), "FN".ljust(8),
"P".ljust(8), "R".ljust(8), "F"+str(args.beta))
for cat, cnts in sorted(best_cats.items()):
cat_p, cat_r, cat_f = computeFScore(cnts[0], cnts[1], cnts[2], args.beta)
print(cat.ljust(14), str(cnts[0]).ljust(8), str(cnts[1]).ljust(8),
str(cnts[2]).ljust(8), str(cat_p).ljust(8), str(cat_r).ljust(8), cat_f)
# Print the overall results.
print("")
print('{:=^46}'.format(title))
print("\t".join(["TP", "FP", "FN", "Prec", "Rec", "F"+str(args.beta)]))
print("\t".join(map(str, [best["tp"], best["fp"],
best["fn"]]+list(computeFScore(best["tp"], best["fp"], best["fn"], args.beta)))))
print('{:=^46}'.format(""))
print("")
if __name__ == "__main__":
# Run the program
main()
import argparse
from collections import Counter
def main():
# Parse command line args
args = parse_args()
# Open hypothesis and reference m2 files and split into chunks
hyp_m2 = open(args.hyp).read().strip().split("\n\n")[args.start:args.end] if args.start is not None and args.end is not None else open(args.hyp).read().strip().split("\n\n")
ref_m2 = open(args.ref).read().strip().split("\n\n")[args.start:args.end] if args.start is not None and args.end is not None else open(args.ref).read().strip().split("\n\n")
# Make sure they have the same number of sentences
assert len(hyp_m2) == len(ref_m2), print(len(hyp_m2), len(ref_m2))
# Store global corpus level best counts here
best_dict = Counter({"tp":0, "fp":0, "fn":0})
best_cats = {}
# Process each sentence
sents = zip(hyp_m2, ref_m2)
for sent_id, sent in enumerate(sents):
# Simplify the edits into lists of lists
# if "A1" in sent[0] or "A1" in sent[1] or sent_id in sent_id_cons:
# sent_id_cons.append(sent_id)
src = sent[0].split("\n")[0]
hyp_edits = simplify_edits(sent[0], args.max_answer_num)
ref_edits = simplify_edits(sent[1], args.max_answer_num)
# Process the edits for detection/correction based on args
hyp_dict = process_edits(hyp_edits, args)
ref_dict = process_edits(ref_edits, args)
if args.reference_num is None or len(ref_dict.keys()) == args.reference_num:
# Evaluate edits and get best TP, FP, FN hyp+ref combo.
count_dict, cat_dict = evaluate_edits(src,
hyp_dict, ref_dict, best_dict, sent_id, args)
# Merge these dicts with best_dict and best_cats
best_dict += Counter(count_dict)
best_cats = merge_dict(best_cats, cat_dict)
# Print results
print_results(best_dict, best_cats, args)
# Parse command line args
def parse_args():
parser = argparse.ArgumentParser(
description="Calculate F-scores for error detection and/or correction.\n"
"Flags let you evaluate at different levels of granularity.",
formatter_class=argparse.RawTextHelpFormatter,
usage="%(prog)s [options] -hyp HYP -ref REF")
parser.add_argument(
"-hyp",
help="A hypothesis M2 file.",
required=True)
parser.add_argument(
"-ref",
help="A reference M2 file.",
required=True)
parser.add_argument(
"--start",
type=int,
default=None
)
parser.add_argument(
"--end",
type=int,
default=None
)
parser.add_argument(
"--max_answer_num",
type=int,
default=None
)
parser.add_argument(
"--reference_num",
type=int,
default=None
)
parser.add_argument(
"-b",
"--beta",
help="Value of beta in F-score. (default: 0.5)",
default=0.5,
type=float)
parser.add_argument(
"-v",
"--verbose",
help="Print verbose output.",
action="store_true")
eval_type = parser.add_mutually_exclusive_group()
eval_type.add_argument(
"-dt",
help="Evaluate Detection in terms of Tokens.",
action="store_true")
eval_type.add_argument(
"-ds",
help="Evaluate Detection in terms of Spans.",
action="store_true")
eval_type.add_argument(
"-cs",
help="Evaluate Correction in terms of Spans. (default)",
action="store_true")
eval_type.add_argument(
"-cse",
help="Evaluate Correction in terms of Spans and Error types.",
action="store_true")
parser.add_argument(
"-single",
help="Only evaluate single token edits; i.e. 0:1, 1:0 or 1:1",
action="store_true")
parser.add_argument(
"-multi",
help="Only evaluate multi token edits; i.e. 2+:n or n:2+",
action="store_true")
parser.add_argument(
"-multi_hyp_avg",
help="When get multiple hypotheses for a sentence, calculate their average F-scores for this sentence.",
action="store_true") # For IAA calculation
parser.add_argument(
"-multi_hyp_max",
help="When get multiple hypotheses for a sentence, calculate their F-scores and select the max one for this sentence.",
action="store_true") # For multiple hypotheses system evaluation
parser.add_argument(
"-filt",
help="Do not evaluate the specified error types.",
nargs="+",
default=[])
parser.add_argument(
"-cat",
help="Show error category scores.\n"
"1: Only show operation tier scores; e.g. R.\n"
"2: Only show main tier scores; e.g. NOUN.\n"
"3: Show all category scores; e.g. R:NOUN.",
choices=[1, 2, 3],
type=int)
args = parser.parse_args()
return args
# Input: An m2 format sentence with edits.
# Output: A list of lists. Each edit: [start, end, cat, cor, coder]
def simplify_edits(sent, max_answer_num):
out_edits = []
# Get the edit lines from an m2 block.
edits = sent.split("\n")
# Loop through the edits
for edit in edits:
# Preprocessing
if edit.startswith("A "):
edit = edit[2:].split("|||") # Ignore "A " then split.
span = edit[0].split()
start = int(span[0])
end = int(span[1])
cat = edit[1]
cor = edit[2].replace(" ", "")
coder = int(edit[-1])
out_edit = [start, end, cat, cor, coder]
out_edits.append(out_edit)
# return [edit for edit in out_edits if edit[-1] in [0,1]]
if max_answer_num is None:
return out_edits
elif max_answer_num == 1:
return [edit for edit in out_edits if edit[-1] == 0]
elif max_answer_num == 2:
return [edit for edit in out_edits if edit[-1] in [0, 1]]
elif max_answer_num == 3:
return [edit for edit in out_edits if edit[-1] in [0, 1, 2]]
# Input 1: A list of edits. Each edit: [start, end, cat, cor, coder]
# Input 2: Command line args
# Output: A dict; key is coder, value is edit dict.
def process_edits(edits, args):
coder_dict = {}
# Add an explicit noop edit if there are no edits.
if not edits: edits = [[-1, -1, "noop", "-NONE-", 0]]
# Loop through the edits
for edit in edits:
# Name the edit elements for clarity
start = edit[0]
end = edit[1]
cat = edit[2]
cor = edit[3]
coder = edit[4]
# Add the coder to the coder_dict if necessary
if coder not in coder_dict: coder_dict[coder] = {}
# Optionally apply filters based on args
# 1. UNK type edits are only useful for detection, not correction.
if not args.dt and not args.ds and cat == "UNK": continue
# 2. Only evaluate single token edits; i.e. 0:1, 1:0 or 1:1
if args.single and (end-start >= 2 or len(cor.split()) >= 2): continue
# 3. Only evaluate multi token edits; i.e. 2+:n or n:2+
if args.multi and end-start < 2 and len(cor.split()) < 2: continue
# 4. If there is a filter, ignore the specified error types
if args.filt and cat in args.filt: continue
# Token Based Detection
if args.dt:
# Preserve noop edits.
if start == -1:
if (start, start) in coder_dict[coder].keys():
coder_dict[coder][(start, start)].append(cat)
else:
coder_dict[coder][(start, start)] = [cat]
# Insertions defined as affecting the token on the right
elif start == end and start >= 0:
if (start, start+1) in coder_dict[coder].keys():
coder_dict[coder][(start, start+1)].append(cat)
else:
coder_dict[coder][(start, start+1)] = [cat]
# Edit spans are split for each token in the range.
else:
for tok_id in range(start, end):
if (tok_id, tok_id+1) in coder_dict[coder].keys():
coder_dict[coder][(tok_id, tok_id+1)].append(cat)
else:
coder_dict[coder][(tok_id, tok_id+1)] = [cat]
# Span Based Detection
elif args.ds:
if (start, end) in coder_dict[coder].keys():
coder_dict[coder][(start, end)].append(cat)
else:
coder_dict[coder][(start, end)] = [cat]
# Span Based Correction
else:
# With error type classification
if args.cse:
if (start, end, cat, cor) in coder_dict[coder].keys():
coder_dict[coder][(start, end, cat, cor)].append(cat)
else:
coder_dict[coder][(start, end, cat, cor)] = [cat]
# Without error type classification
else:
if (start, end, cor) in coder_dict[coder].keys():
coder_dict[coder][(start, end, cor)].append(cat)
else:
coder_dict[coder][(start, end, cor)] = [cat]
return coder_dict
# Input 1: A hyp dict; key is coder_id, value is dict of processed hyp edits.
# Input 2: A ref dict; key is coder_id, value is dict of processed ref edits.
# Input 3: A dictionary of the best corpus level TP, FP and FN counts so far.
# Input 4: Sentence ID (for verbose output only)
# Input 5: Command line args
# Output 1: A dict of the best corpus level TP, FP and FN for the input sentence.
# Output 2: The corresponding error type dict for the above dict.
def evaluate_edits(src, hyp_dict, ref_dict, best, sent_id, args):
# Store the best sentence level scores and hyp+ref combination IDs
# best_f is initialised as -1 cause 0 is a valid result.
best_tp, best_fp, best_fn, best_f, best_hyp, best_ref = 0, 0, 0, -1, 0, 0
best_cat = {}
# skip not annotatable sentence
if len(ref_dict.keys()) == 1:
ref_id = list(ref_dict.keys())[0]
if len(ref_dict[ref_id].keys()) == 1:
cat = list(ref_dict[ref_id].values())[0][0]
if cat == "NA":
best_dict = {"tp":best_tp, "fp":best_fp, "fn":best_fn}
return best_dict, best_cat
# Compare each hyp and ref combination
for hyp_id in hyp_dict.keys():
for ref_id in ref_dict.keys():
# Get the local counts for the current combination.
tp, fp, fn, cat_dict = compareEdits(hyp_dict[hyp_id], ref_dict[ref_id])
# Compute the local sentence scores (for verbose output only)
loc_p, loc_r, loc_f = computeFScore(tp, fp, fn, args.beta)
# Compute the global sentence scores
p, r, f = computeFScore(
tp+best["tp"], fp+best["fp"], fn+best["fn"], args.beta)
# Save the scores if they are better in terms of:
# 1. Higher F-score
# 2. Same F-score, higher TP
# 3. Same F-score and TP, lower FP
# 4. Same F-score, TP and FP, lower FN
if (f > best_f) or \
(f == best_f and tp > best_tp) or \
(f == best_f and tp == best_tp and fp < best_fp) or \
(f == best_f and tp == best_tp and fp == best_fp and fn < best_fn):
best_tp, best_fp, best_fn = tp, fp, fn
best_f, best_hyp, best_ref = f, hyp_id, ref_id
best_cat = cat_dict
# Verbose output
if args.verbose:
# Prepare verbose output edits.
hyp_verb = list(sorted(hyp_dict[hyp_id].keys()))
ref_verb = list(sorted(ref_dict[ref_id].keys()))
# Ignore noop edits
if not hyp_verb or hyp_verb[0][0] == -1: hyp_verb = []
if not ref_verb or ref_verb[0][0] == -1: ref_verb = []
# Print verbose info
print('{:-^40}'.format(""))
print("SENTENCE "+str(sent_id)+src[1:])
print('{:-^40}'.format(""))
print("SENTENCE "+str(sent_id)+" - HYP "+str(hyp_id)+" - REF "+str(ref_id))
print("HYPOTHESIS EDITS :", hyp_verb)
print("REFERENCE EDITS :", ref_verb)
print("Local TP/FP/FN :", str(tp), str(fp), str(fn))
print("Local P/R/F"+str(args.beta)+" :", str(loc_p), str(loc_r), str(loc_f))
print("Global TP/FP/FN :", str(tp+best["tp"]), str(fp+best["fp"]), str(fn+best["fn"]))
print("Global P/R/F"+str(args.beta)+" :", str(p), str(r), str(f))
# Verbose output: display the best hyp+ref combination
if args.verbose:
print('{:-^40}'.format(""))
print("^^ HYP "+str(best_hyp)+", REF "+str(best_ref)+" chosen for sentence "+str(sent_id))
# Save the best TP, FP and FNs as a dict, and return this and the best_cat dict
best_dict = {"tp":best_tp, "fp":best_fp, "fn":best_fn}
return best_dict, best_cat
# Input 1: A dictionary of hypothesis edits for a single system.
# Input 2: A dictionary of reference edits for a single annotator.
# Output 1-3: The TP, FP and FN for the hyp vs the given ref annotator.
# Output 4: A dictionary of the error type counts.
def compareEdits(hyp_edits, ref_edits):
tp = 0 # True Positives
fp = 0 # False Positives
fn = 0 # False Negatives
cat_dict = {} # {cat: [tp, fp, fn], ...}
for h_edit, h_cats in hyp_edits.items():
# noop hyp edits cannot be TP or FP
if h_cats[0] == "noop": continue
# TRUE POSITIVES
if h_edit in ref_edits.keys():
# On occasion, multiple tokens at same span.
for h_cat in ref_edits[h_edit]: # Use ref dict for TP
tp += 1
# Each dict value [TP, FP, FN]
if h_cat in cat_dict.keys():
cat_dict[h_cat][0] += 1
else:
cat_dict[h_cat] = [1, 0, 0]
# FALSE POSITIVES
else:
# On occasion, multiple tokens at same span.
for h_cat in h_cats:
fp += 1
# Each dict value [TP, FP, FN]
if h_cat in cat_dict.keys():
cat_dict[h_cat][1] += 1
else:
cat_dict[h_cat] = [0, 1, 0]
for r_edit, r_cats in ref_edits.items():
# noop ref edits cannot be FN
if r_cats[0] == "noop": continue
# FALSE NEGATIVES
if r_edit not in hyp_edits.keys():
# On occasion, multiple tokens at same span.
for r_cat in r_cats:
fn += 1
# Each dict value [TP, FP, FN]
if r_cat in cat_dict.keys():
cat_dict[r_cat][2] += 1
else:
cat_dict[r_cat] = [0, 0, 1]
return tp, fp, fn, cat_dict
# Input 1-3: True positives, false positives, false negatives
# Input 4: Value of beta in F-score.
# Output 1-3: Precision, Recall and F-score rounded to 4dp.
def computeFScore(tp, fp, fn, beta):
p = float(tp)/(tp+fp) if fp else 1.0
r = float(tp)/(tp+fn) if fn else 1.0
f = float((1+(beta**2))*p*r)/(((beta**2)*p)+r) if p+r else 0.0
return round(p, 4), round(r, 4), round(f, 4)
# Input 1-2: Two error category dicts. Key is cat, value is list of TP, FP, FN.
# Output: The dictionaries combined with cumulative TP, FP, FN.
def merge_dict(dict1, dict2):
for cat, stats in dict2.items():
if cat in dict1.keys():
dict1[cat] = [x+y for x, y in zip(dict1[cat], stats)]
else:
dict1[cat] = stats
return dict1
# Input 1: A dict; key is error cat, value is counts for [tp, fp, fn]
# Input 2: Integer value denoting level of error category granularity.
# 1: Operation tier; e.g. M, R, U. 2: Main tier; e.g. NOUN, VERB 3: Everything.
# Output: A dictionary of category TP, FP and FN based on Input 2.
def processCategories(cat_dict, setting):
# Otherwise, do some processing.
proc_cat_dict = {}
for cat, cnt in cat_dict.items():
if cat == "UNK":
proc_cat_dict[cat] = cnt
continue
# M, U, R or UNK combined only.
if setting == 1:
if cat[0] in proc_cat_dict.keys():
proc_cat_dict[cat[0]] = [x+y for x, y in zip(proc_cat_dict[cat[0]], cnt)]
else:
proc_cat_dict[cat[0]] = cnt
# Everything without M, U or R.
elif setting == 2:
if cat[2:] in proc_cat_dict.keys():
proc_cat_dict[cat[2:]] = [x+y for x, y in zip(proc_cat_dict[cat[2:]], cnt)]
else:
proc_cat_dict[cat[2:]] = cnt
# All error category combinations
else:
return cat_dict
return proc_cat_dict
# Input 1: A dict of global best TP, FP and FNs
# Input 2: A dict of error types and counts for those TP, FP and FNs
# Input 3: Command line args
def print_results(best, best_cats, args):
# Prepare output title.
if args.dt: title = " Token-Based Detection "
elif args.ds: title = " Span-Based Detection "
elif args.cse: title = " Span-Based Correction + Classification "
else: title = " Span-Based Correction "
# Category Scores
if args.cat:
best_cats = processCategories(best_cats, args.cat)
print("")
print('{:=^66}'.format(title))
print("Category".ljust(14), "TP".ljust(8), "FP".ljust(8), "FN".ljust(8),
"P".ljust(8), "R".ljust(8), "F"+str(args.beta))
for cat, cnts in sorted(best_cats.items()):
cat_p, cat_r, cat_f = computeFScore(cnts[0], cnts[1], cnts[2], args.beta)
print(cat.ljust(14), str(cnts[0]).ljust(8), str(cnts[1]).ljust(8),
str(cnts[2]).ljust(8), str(cat_p).ljust(8), str(cat_r).ljust(8), cat_f)
# Print the overall results.
print("")
print('{:=^46}'.format(title))
print("\t".join(["TP", "FP", "FN", "Prec", "Rec", "F"+str(args.beta)]))
print("\t".join(map(str, [best["tp"], best["fp"],
best["fn"]]+list(computeFScore(best["tp"], best["fp"], best["fn"], args.beta)))))
print('{:=^46}'.format(""))
print("")
if __name__ == "__main__":
# Run the program
main()
from rouge_chinese import Rouge
import jieba
from nltk.translate.gleu_score import corpus_gleu
def compute_f1_two_sets(pred_set, gt_set):
precision = len(pred_set.intersection(gt_set)) / len(pred_set) if len(pred_set) > 0 else 0
recall = len(pred_set.intersection(gt_set)) / len(gt_set) if len(gt_set) > 0 else 0
f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0
return f1
def multi_choice_judge(prediction, option_list, answer_token):
# a dict, key: letters in the option list, value: count of the letter in the prediction
count_dict, abstention, accuracy = {}, 0, 0
for option in option_list:
option_count = prediction.count(option)
count_dict[option] = 1 if option_count > 0 else 0 # multiple occurrence of the same letter is counted as 1
if sum(count_dict.values()) == 0:
abstention = 1
# if the answer token is the only predicted token, the prediction is correct
elif count_dict[answer_token] == 1 and sum(count_dict.values()) == 1:
accuracy = 1
return {"score": accuracy, "abstention": abstention}
"""
compute the rouge score.
hyps and refs are lists of hyposisis and reference strings
empty predictions are replaces with 无内容
"""
def compute_rouge(hyps, refs):
assert(len(hyps) == len(refs))
hyps = [' '.join(jieba.cut(h)) for h in hyps]
hyps = [h if h.strip() != "" else "无内容" for h in hyps]
refs = [' '.join(jieba.cut(r)) for r in refs]
return Rouge().get_scores(hyps, refs)
"""
compute the gleu score.
hyps and refs are lists of hyposisis and reference strings
empty predictions are replaces with 无内容
"""
def compute_gleu(hyps, refs):
assert(len(hyps) == len(refs))
hyps = [' '.join(jieba.cut(h)) for h in hyps]
hyps = [h if h.strip() != "" else "无内容" for h in hyps]
refs = [[' '.join(jieba.cut(r))] for r in refs]
return corpus_gleu(refs, hyps)
from rouge_chinese import Rouge
import jieba
from nltk.translate.gleu_score import corpus_gleu
def compute_f1_two_sets(pred_set, gt_set):
precision = len(pred_set.intersection(gt_set)) / len(pred_set) if len(pred_set) > 0 else 0
recall = len(pred_set.intersection(gt_set)) / len(gt_set) if len(gt_set) > 0 else 0
f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0
return f1
def multi_choice_judge(prediction, option_list, answer_token):
# a dict, key: letters in the option list, value: count of the letter in the prediction
count_dict, abstention, accuracy = {}, 0, 0
for option in option_list:
option_count = prediction.count(option)
count_dict[option] = 1 if option_count > 0 else 0 # multiple occurrence of the same letter is counted as 1
if sum(count_dict.values()) == 0:
abstention = 1
# if the answer token is the only predicted token, the prediction is correct
elif count_dict[answer_token] == 1 and sum(count_dict.values()) == 1:
accuracy = 1
return {"score": accuracy, "abstention": abstention}
"""
compute the rouge score.
hyps and refs are lists of hyposisis and reference strings
empty predictions are replaces with 无内容
"""
def compute_rouge(hyps, refs):
assert(len(hyps) == len(refs))
hyps = [' '.join(jieba.cut(h)) for h in hyps]
hyps = [h if h.strip() != "" else "无内容" for h in hyps]
refs = [' '.join(jieba.cut(r)) for r in refs]
return Rouge().get_scores(hyps, refs)
"""
compute the gleu score.
hyps and refs are lists of hyposisis and reference strings
empty predictions are replaces with 无内容
"""
def compute_gleu(hyps, refs):
assert(len(hyps) == len(refs))
hyps = [' '.join(jieba.cut(h)) for h in hyps]
hyps = [h if h.strip() != "" else "无内容" for h in hyps]
refs = [[' '.join(jieba.cut(r))] for r in refs]
return corpus_gleu(refs, hyps)
import numpy as np
from typing import List, Tuple, Dict
from modules.tokenizer import Tokenizer
import os
from string import punctuation
REAL_PATH = os.path.split(os.path.realpath(__file__))[0]
chinese_punct = "!?。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘'‛“”„‟…‧﹏"
english_punct = punctuation
punct = chinese_punct + english_punct
def check_all_chinese(word):
"""
判断一个单词是否全部由中文组成
:param word:
:return:
"""
return all(['\u4e00' <= ch <= '\u9fff' for ch in word])
def read_cilin():
"""
Cilin 詞林 is a thesaurus with semantic information
"""
# TODO -- fix this path
project_dir = os.path.dirname(os.path.dirname(__file__)) # ymliu@2023.5.30 fix the path
lines = open(os.path.join(project_dir, "data", "cilin.txt"), "r", encoding="gbk").read().strip().split("\n")
semantic_dict = {}
semantic_classes = {}
for line in lines:
code, *words = line.split(" ")
for word in words:
semantic_dict[word] = code
# make reverse dict
if code in semantic_classes:
semantic_classes[code] += words
else:
semantic_classes[code] = words
return semantic_dict, semantic_classes
def read_confusion():
confusion_dict = {}
project_dir = os.path.dirname(os.path.dirname(__file__)) # ymliu@2023.5.30 fix the path
with open(os.path.join(project_dir, "data", "confusion_dict.txt"), "r", encoding="utf-8") as f:
for line in f:
li = line.rstrip('\n').split(" ")
confusion_dict[li[0]] = li[1:]
return confusion_dict
class Alignment:
"""
对齐错误句子和正确句子,
使用编辑距离算法抽取编辑操作
"""
def __init__(
self,
semantic_dict: Dict,
confusion_dict: Dict,
granularity: str = "word",
) -> None:
"""
构造函数
:param semantic_dict: 语义词典(大词林)
:param confusion_dict: 字符混淆集
"""
self.insertion_cost = 1
self.deletion_cost = 1
self.semantic_dict = semantic_dict
self.confusion_dict = confusion_dict
# Because we use character level tokenization, this doesn't currently use POS
self._open_pos = {} # 如果是词级别,还可以利用词性是否相同来计算cost
self.granularity = granularity # word-level or character-level
self.align_seqs = []
def __call__(self,
src: List[Tuple],
tgt: List[Tuple],
verbose: bool = False):
cost_matrix, oper_matrix = self.align(src, tgt)
align_seq = self.get_cheapest_align_seq(oper_matrix)
if verbose:
print("========== Seg. and POS: ==========")
print(src)
print(tgt)
print("========== Cost Matrix ==========")
print(cost_matrix)
print("========== Oper Matrix ==========")
print(oper_matrix)
print("========== Alignment ==========")
print(align_seq)
print("========== Results ==========")
for a in align_seq:
print(a[0], src[a[1]: a[2]], tgt[a[3]: a[4]])
return align_seq
def _get_semantic_class(self, word):
"""
NOTE: Based on the paper:
Improved-Edit-Distance Kernel for Chinese Relation Extraction
获取每个词语的语义类别(基于大词林,有三个级别)
"""
if word in self.semantic_dict:
code = self.semantic_dict[word]
high, mid, low = code[0], code[1], code[2:4]
return high, mid, low
else: # unknown
return None
@staticmethod
def _get_class_diff(a_class, b_class):
"""
d == 3 for equivalent semantics
d == 0 for completely different semantics
根据大词林的信息,计算两个词的语义类别的差距
"""
d = sum([a == b for a, b in zip(a_class, b_class)])
return d
def _get_semantic_cost(self, a, b):
"""
计算基于语义信息的替换操作cost
:param a: 单词a的语义类别
:param b: 单词b的语义类别
:return: 替换编辑代价
"""
a_class = self._get_semantic_class(a)
b_class = self._get_semantic_class(b)
# unknown class, default to 1
if a_class is None or b_class is None:
return 4
elif a_class == b_class:
return 0
else:
return 2 * (3 - self._get_class_diff(a_class, b_class))
def _get_pos_cost(self, a_pos, b_pos):
"""
计算基于词性信息的编辑距离cost
:param a_pos: 单词a的词性
:param b_pos: 单词b的词性
:return: 替换编辑代价
"""
if a_pos == b_pos:
return 0
elif a_pos in self._open_pos and b_pos in self._open_pos:
return 0.25
else:
return 0.499
def _get_char_cost(self, a, b, pinyin_a, pinyin_b):
"""
NOTE: This is a replacement of ERRANTS lemma cost for Chinese
计算基于字符相似度的编辑距离cost
"""
if not (check_all_chinese(a) and check_all_chinese(b)):
return 0.5
if len(a) > len(b):
a, b = b, a
pinyin_a, pinyin_b = pinyin_b, pinyin_a
if a == b:
return 0
else:
return self._get_spell_cost(a, b, pinyin_a, pinyin_b)
def _get_spell_cost(self, a, b, pinyin_a, pinyin_b):
"""
计算两个单词拼写相似度,分别由字形相似度和字音相似度组成
:param a: 单词a
:param b: 单词b,且单词a的长度小于等于b
:param pinyin_a: 单词a的拼音
:param pinyin_b: 单词b的拼音
:return: 替换操作cost
"""
count = 0
for i in range(len(a)):
for j in range(len(b)):
if a[i] == b[j] or (set(pinyin_a) & set(pinyin_b)) or (b[j] in self.confusion_dict.keys() and a[i] in self.confusion_dict[b[j]]) or (a[i] in self.confusion_dict.keys() and b[j] in self.confusion_dict[a[i]]):
count += 1
break
return (len(a) - count) / (len(a) * 2)
def get_sub_cost(self, a_seg, b_seg):
"""
Calculate the substitution cost between words a and b
计算两个单词替换操作的编辑cost,最大为2,等于一次删除和一次添加
"""
if a_seg[0] == b_seg[0]:
return 0
if self.granularity == "word": # 词级别可以额外利用词性信息
semantic_cost = self._get_semantic_cost(a_seg[0], b_seg[0]) / 6.0
pos_cost = self._get_pos_cost(a_seg[1], b_seg[1])
char_cost = self._get_char_cost(a_seg[0], b_seg[0], a_seg[2], b_seg[2])
return semantic_cost + pos_cost + char_cost
else: # 字级别只能利用字义信息(从大词林中获取)和字面相似度信息
semantic_cost = self._get_semantic_cost(a_seg[0], b_seg[0]) / 6.0
if a_seg[0] in punct and b_seg[0] in punct:
pos_cost = 0.0
elif a_seg[0] not in punct and b_seg[0] not in punct:
pos_cost = 0.25
else:
pos_cost = 0.499
# pos_cost = 0.0 if (a_seg[0] in punct and b_seg[0] in punct) or (a_seg[0] not in punct and b_seg[0] not in punct) else 0.5
char_cost = self._get_char_cost(a_seg[0], b_seg[0], a_seg[2], b_seg[2])
return semantic_cost + char_cost + pos_cost
def align(self,
src: List[Tuple],
tgt: List[Tuple]):
"""
Based on ERRANT's alignment
基于改进的动态规划算法,为原句子的每个字打上编辑标签,以便使它能够成功转换为目标句子。
编辑操作类别:
1) M:Match,即KEEP,即当前字保持不变
2) D:Delete,删除,即当前字需要被删除
3) I:Insert,插入,即当前字需要被插入
4) T:Transposition,移位操作,即涉及到词序问题
"""
cost_matrix = np.zeros((len(src) + 1, len(tgt) + 1)) # 编辑cost矩阵
oper_matrix = np.full(
(len(src) + 1, len(tgt) + 1), "O", dtype=object
) # 操作矩阵
# Fill in the edges
for i in range(1, len(src) + 1):
cost_matrix[i][0] = cost_matrix[i - 1][0] + 1
oper_matrix[i][0] = ["D"]
for j in range(1, len(tgt) + 1):
cost_matrix[0][j] = cost_matrix[0][j - 1] + 1
oper_matrix[0][j] = ["I"]
# Loop through the cost matrix
for i in range(len(src)):
for j in range(len(tgt)):
# Matches
if src[i][0] == tgt[j][0]: # 如果两个字相等,则匹配成功(Match),编辑距离为0
cost_matrix[i + 1][j + 1] = cost_matrix[i][j]
oper_matrix[i + 1][j + 1] = ["M"]
# Non-matches
else:
del_cost = cost_matrix[i][j + 1] + self.deletion_cost # 由删除动作得到的总cost
ins_cost = cost_matrix[i + 1][j] + self.insertion_cost # 由插入动作得到的总cost
sub_cost = cost_matrix[i][j] + self.get_sub_cost(
src[i], tgt[j]
) # 由替换动作得到的总cost
# Calculate transposition cost
# 计算移位操作的总cost
trans_cost = float("inf")
k = 1
while (
i - k >= 0
and j - k >= 0
and cost_matrix[i - k + 1][j - k + 1]
!= cost_matrix[i - k][j - k]
):
p1 = sorted([a[0] for a in src][i - k: i + 1])
p2 = sorted([b[0] for b in tgt][j - k: j + 1])
if p1 == p2:
trans_cost = cost_matrix[i - k][j - k] + k
break
k += 1
costs = [trans_cost, sub_cost, ins_cost, del_cost]
ind = costs.index(min(costs))
cost_matrix[i + 1][j + 1] = costs[ind]
# ind = costs.index(costs[ind], ind+1)
for idx, cost in enumerate(costs):
if cost == costs[ind]:
if idx == 0:
if oper_matrix[i + 1][j + 1] == "O":
oper_matrix[i + 1][j + 1] = ["T" + str(k + 1)]
else:
oper_matrix[i + 1][j + 1].append("T" + str(k + 1))
elif idx == 1:
if oper_matrix[i + 1][j + 1] == "O":
oper_matrix[i + 1][j + 1] = ["S"]
else:
oper_matrix[i + 1][j + 1].append("S")
elif idx == 2:
if oper_matrix[i + 1][j + 1] == "O":
oper_matrix[i + 1][j + 1] = ["I"]
else:
oper_matrix[i + 1][j + 1].append("I")
else:
if oper_matrix[i + 1][j + 1] == "O":
oper_matrix[i + 1][j + 1] = ["D"]
else:
oper_matrix[i + 1][j + 1].append("D")
return cost_matrix, oper_matrix
def _dfs(self, i, j, align_seq_now, oper_matrix, strategy="all"):
"""
深度优先遍历,获取最小编辑距离相同的所有序列
"""
if i + j == 0:
self.align_seqs.append(align_seq_now)
else:
ops = oper_matrix[i][j] # 可以类比成搜索一棵树从根结点到叶子结点的所有路径
if strategy != "all": ops = ops[:1]
for op in ops:
if op in {"M", "S"}:
self._dfs(i - 1, j - 1, align_seq_now + [(op, i - 1, i, j - 1, j)], oper_matrix, strategy)
elif op == "D":
self._dfs(i - 1, j, align_seq_now + [(op, i - 1, i, j, j)], oper_matrix, strategy)
elif op == "I":
self._dfs(i, j - 1, align_seq_now + [(op, i, i, j - 1, j)], oper_matrix, strategy)
else:
k = int(op[1:])
self._dfs(i - k, j - k, align_seq_now + [(op, i - k, i, j - k, j)], oper_matrix, strategy)
def get_cheapest_align_seq(self, oper_matrix):
"""
回溯获得编辑距离最小的编辑序列
"""
self.align_seqs = []
i = oper_matrix.shape[0] - 1
j = oper_matrix.shape[1] - 1
if abs(i - j) > 10:
self._dfs(i, j , [], oper_matrix, "first")
else:
self._dfs(i, j , [], oper_matrix, "all")
final_align_seqs = [seq[::-1] for seq in self.align_seqs]
return final_align_seqs
if __name__ == "__main__":
tokenizer = Tokenizer("word")
semantic_dict, semantic_class = read_cilin()
confusion_dict = read_confusion()
alignment = Alignment(semantic_dict, confusion_dict)
sents = ["首先 , 我们 得 准备 : 大 虾六 到 九 只 、 盐 一 茶匙 、 已 搾 好 的 柠檬汁 三 汤匙 、 泰国 柠檬 叶三叶 、 柠檬 香草 一 根 、 鱼酱 两 汤匙 、 辣椒 6 粒 , 纯净 水 4量杯 、 香菜 半量杯 和 草菇 10 个 。".replace(" ", ""), "首先 , 我们 得 准备 : 大 虾六 到 九 只 、 盐 一 茶匙 、 已 榨 好 的 柠檬汁 三 汤匙 、 泰国 柠檬 叶三叶 、 柠檬 香草 一 根 、 鱼酱 两 汤匙 、 辣椒 六 粒 , 纯净 水 四 量杯 、 香菜 半量杯 和 草菇 十 个 。".replace(" ", "")]
src, tgt = tokenizer(sents)
import numpy as np
from typing import List, Tuple, Dict
from modules.tokenizer import Tokenizer
import os
from string import punctuation
REAL_PATH = os.path.split(os.path.realpath(__file__))[0]
chinese_punct = "!?。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘'‛“”„‟…‧﹏"
english_punct = punctuation
punct = chinese_punct + english_punct
def check_all_chinese(word):
"""
判断一个单词是否全部由中文组成
:param word:
:return:
"""
return all(['\u4e00' <= ch <= '\u9fff' for ch in word])
def read_cilin():
"""
Cilin 詞林 is a thesaurus with semantic information
"""
# TODO -- fix this path
project_dir = os.path.dirname(os.path.dirname(__file__)) # ymliu@2023.5.30 fix the path
lines = open(os.path.join(project_dir, "data", "cilin.txt"), "r", encoding="gbk").read().strip().split("\n")
semantic_dict = {}
semantic_classes = {}
for line in lines:
code, *words = line.split(" ")
for word in words:
semantic_dict[word] = code
# make reverse dict
if code in semantic_classes:
semantic_classes[code] += words
else:
semantic_classes[code] = words
return semantic_dict, semantic_classes
def read_confusion():
confusion_dict = {}
project_dir = os.path.dirname(os.path.dirname(__file__)) # ymliu@2023.5.30 fix the path
with open(os.path.join(project_dir, "data", "confusion_dict.txt"), "r", encoding="utf-8") as f:
for line in f:
li = line.rstrip('\n').split(" ")
confusion_dict[li[0]] = li[1:]
return confusion_dict
class Alignment:
"""
对齐错误句子和正确句子,
使用编辑距离算法抽取编辑操作
"""
def __init__(
self,
semantic_dict: Dict,
confusion_dict: Dict,
granularity: str = "word",
) -> None:
"""
构造函数
:param semantic_dict: 语义词典(大词林)
:param confusion_dict: 字符混淆集
"""
self.insertion_cost = 1
self.deletion_cost = 1
self.semantic_dict = semantic_dict
self.confusion_dict = confusion_dict
# Because we use character level tokenization, this doesn't currently use POS
self._open_pos = {} # 如果是词级别,还可以利用词性是否相同来计算cost
self.granularity = granularity # word-level or character-level
self.align_seqs = []
def __call__(self,
src: List[Tuple],
tgt: List[Tuple],
verbose: bool = False):
cost_matrix, oper_matrix = self.align(src, tgt)
align_seq = self.get_cheapest_align_seq(oper_matrix)
if verbose:
print("========== Seg. and POS: ==========")
print(src)
print(tgt)
print("========== Cost Matrix ==========")
print(cost_matrix)
print("========== Oper Matrix ==========")
print(oper_matrix)
print("========== Alignment ==========")
print(align_seq)
print("========== Results ==========")
for a in align_seq:
print(a[0], src[a[1]: a[2]], tgt[a[3]: a[4]])
return align_seq
def _get_semantic_class(self, word):
"""
NOTE: Based on the paper:
Improved-Edit-Distance Kernel for Chinese Relation Extraction
获取每个词语的语义类别(基于大词林,有三个级别)
"""
if word in self.semantic_dict:
code = self.semantic_dict[word]
high, mid, low = code[0], code[1], code[2:4]
return high, mid, low
else: # unknown
return None
@staticmethod
def _get_class_diff(a_class, b_class):
"""
d == 3 for equivalent semantics
d == 0 for completely different semantics
根据大词林的信息,计算两个词的语义类别的差距
"""
d = sum([a == b for a, b in zip(a_class, b_class)])
return d
def _get_semantic_cost(self, a, b):
"""
计算基于语义信息的替换操作cost
:param a: 单词a的语义类别
:param b: 单词b的语义类别
:return: 替换编辑代价
"""
a_class = self._get_semantic_class(a)
b_class = self._get_semantic_class(b)
# unknown class, default to 1
if a_class is None or b_class is None:
return 4
elif a_class == b_class:
return 0
else:
return 2 * (3 - self._get_class_diff(a_class, b_class))
def _get_pos_cost(self, a_pos, b_pos):
"""
计算基于词性信息的编辑距离cost
:param a_pos: 单词a的词性
:param b_pos: 单词b的词性
:return: 替换编辑代价
"""
if a_pos == b_pos:
return 0
elif a_pos in self._open_pos and b_pos in self._open_pos:
return 0.25
else:
return 0.499
def _get_char_cost(self, a, b, pinyin_a, pinyin_b):
"""
NOTE: This is a replacement of ERRANTS lemma cost for Chinese
计算基于字符相似度的编辑距离cost
"""
if not (check_all_chinese(a) and check_all_chinese(b)):
return 0.5
if len(a) > len(b):
a, b = b, a
pinyin_a, pinyin_b = pinyin_b, pinyin_a
if a == b:
return 0
else:
return self._get_spell_cost(a, b, pinyin_a, pinyin_b)
def _get_spell_cost(self, a, b, pinyin_a, pinyin_b):
"""
计算两个单词拼写相似度,分别由字形相似度和字音相似度组成
:param a: 单词a
:param b: 单词b,且单词a的长度小于等于b
:param pinyin_a: 单词a的拼音
:param pinyin_b: 单词b的拼音
:return: 替换操作cost
"""
count = 0
for i in range(len(a)):
for j in range(len(b)):
if a[i] == b[j] or (set(pinyin_a) & set(pinyin_b)) or (b[j] in self.confusion_dict.keys() and a[i] in self.confusion_dict[b[j]]) or (a[i] in self.confusion_dict.keys() and b[j] in self.confusion_dict[a[i]]):
count += 1
break
return (len(a) - count) / (len(a) * 2)
def get_sub_cost(self, a_seg, b_seg):
"""
Calculate the substitution cost between words a and b
计算两个单词替换操作的编辑cost,最大为2,等于一次删除和一次添加
"""
if a_seg[0] == b_seg[0]:
return 0
if self.granularity == "word": # 词级别可以额外利用词性信息
semantic_cost = self._get_semantic_cost(a_seg[0], b_seg[0]) / 6.0
pos_cost = self._get_pos_cost(a_seg[1], b_seg[1])
char_cost = self._get_char_cost(a_seg[0], b_seg[0], a_seg[2], b_seg[2])
return semantic_cost + pos_cost + char_cost
else: # 字级别只能利用字义信息(从大词林中获取)和字面相似度信息
semantic_cost = self._get_semantic_cost(a_seg[0], b_seg[0]) / 6.0
if a_seg[0] in punct and b_seg[0] in punct:
pos_cost = 0.0
elif a_seg[0] not in punct and b_seg[0] not in punct:
pos_cost = 0.25
else:
pos_cost = 0.499
# pos_cost = 0.0 if (a_seg[0] in punct and b_seg[0] in punct) or (a_seg[0] not in punct and b_seg[0] not in punct) else 0.5
char_cost = self._get_char_cost(a_seg[0], b_seg[0], a_seg[2], b_seg[2])
return semantic_cost + char_cost + pos_cost
def align(self,
src: List[Tuple],
tgt: List[Tuple]):
"""
Based on ERRANT's alignment
基于改进的动态规划算法,为原句子的每个字打上编辑标签,以便使它能够成功转换为目标句子。
编辑操作类别:
1) M:Match,即KEEP,即当前字保持不变
2) D:Delete,删除,即当前字需要被删除
3) I:Insert,插入,即当前字需要被插入
4) T:Transposition,移位操作,即涉及到词序问题
"""
cost_matrix = np.zeros((len(src) + 1, len(tgt) + 1)) # 编辑cost矩阵
oper_matrix = np.full(
(len(src) + 1, len(tgt) + 1), "O", dtype=object
) # 操作矩阵
# Fill in the edges
for i in range(1, len(src) + 1):
cost_matrix[i][0] = cost_matrix[i - 1][0] + 1
oper_matrix[i][0] = ["D"]
for j in range(1, len(tgt) + 1):
cost_matrix[0][j] = cost_matrix[0][j - 1] + 1
oper_matrix[0][j] = ["I"]
# Loop through the cost matrix
for i in range(len(src)):
for j in range(len(tgt)):
# Matches
if src[i][0] == tgt[j][0]: # 如果两个字相等,则匹配成功(Match),编辑距离为0
cost_matrix[i + 1][j + 1] = cost_matrix[i][j]
oper_matrix[i + 1][j + 1] = ["M"]
# Non-matches
else:
del_cost = cost_matrix[i][j + 1] + self.deletion_cost # 由删除动作得到的总cost
ins_cost = cost_matrix[i + 1][j] + self.insertion_cost # 由插入动作得到的总cost
sub_cost = cost_matrix[i][j] + self.get_sub_cost(
src[i], tgt[j]
) # 由替换动作得到的总cost
# Calculate transposition cost
# 计算移位操作的总cost
trans_cost = float("inf")
k = 1
while (
i - k >= 0
and j - k >= 0
and cost_matrix[i - k + 1][j - k + 1]
!= cost_matrix[i - k][j - k]
):
p1 = sorted([a[0] for a in src][i - k: i + 1])
p2 = sorted([b[0] for b in tgt][j - k: j + 1])
if p1 == p2:
trans_cost = cost_matrix[i - k][j - k] + k
break
k += 1
costs = [trans_cost, sub_cost, ins_cost, del_cost]
ind = costs.index(min(costs))
cost_matrix[i + 1][j + 1] = costs[ind]
# ind = costs.index(costs[ind], ind+1)
for idx, cost in enumerate(costs):
if cost == costs[ind]:
if idx == 0:
if oper_matrix[i + 1][j + 1] == "O":
oper_matrix[i + 1][j + 1] = ["T" + str(k + 1)]
else:
oper_matrix[i + 1][j + 1].append("T" + str(k + 1))
elif idx == 1:
if oper_matrix[i + 1][j + 1] == "O":
oper_matrix[i + 1][j + 1] = ["S"]
else:
oper_matrix[i + 1][j + 1].append("S")
elif idx == 2:
if oper_matrix[i + 1][j + 1] == "O":
oper_matrix[i + 1][j + 1] = ["I"]
else:
oper_matrix[i + 1][j + 1].append("I")
else:
if oper_matrix[i + 1][j + 1] == "O":
oper_matrix[i + 1][j + 1] = ["D"]
else:
oper_matrix[i + 1][j + 1].append("D")
return cost_matrix, oper_matrix
def _dfs(self, i, j, align_seq_now, oper_matrix, strategy="all"):
"""
深度优先遍历,获取最小编辑距离相同的所有序列
"""
if i + j == 0:
self.align_seqs.append(align_seq_now)
else:
ops = oper_matrix[i][j] # 可以类比成搜索一棵树从根结点到叶子结点的所有路径
if strategy != "all": ops = ops[:1]
for op in ops:
if op in {"M", "S"}:
self._dfs(i - 1, j - 1, align_seq_now + [(op, i - 1, i, j - 1, j)], oper_matrix, strategy)
elif op == "D":
self._dfs(i - 1, j, align_seq_now + [(op, i - 1, i, j, j)], oper_matrix, strategy)
elif op == "I":
self._dfs(i, j - 1, align_seq_now + [(op, i, i, j - 1, j)], oper_matrix, strategy)
else:
k = int(op[1:])
self._dfs(i - k, j - k, align_seq_now + [(op, i - k, i, j - k, j)], oper_matrix, strategy)
def get_cheapest_align_seq(self, oper_matrix):
"""
回溯获得编辑距离最小的编辑序列
"""
self.align_seqs = []
i = oper_matrix.shape[0] - 1
j = oper_matrix.shape[1] - 1
if abs(i - j) > 10:
self._dfs(i, j , [], oper_matrix, "first")
else:
self._dfs(i, j , [], oper_matrix, "all")
final_align_seqs = [seq[::-1] for seq in self.align_seqs]
return final_align_seqs
if __name__ == "__main__":
tokenizer = Tokenizer("word")
semantic_dict, semantic_class = read_cilin()
confusion_dict = read_confusion()
alignment = Alignment(semantic_dict, confusion_dict)
sents = ["首先 , 我们 得 准备 : 大 虾六 到 九 只 、 盐 一 茶匙 、 已 搾 好 的 柠檬汁 三 汤匙 、 泰国 柠檬 叶三叶 、 柠檬 香草 一 根 、 鱼酱 两 汤匙 、 辣椒 6 粒 , 纯净 水 4量杯 、 香菜 半量杯 和 草菇 10 个 。".replace(" ", ""), "首先 , 我们 得 准备 : 大 虾六 到 九 只 、 盐 一 茶匙 、 已 榨 好 的 柠檬汁 三 汤匙 、 泰国 柠檬 叶三叶 、 柠檬 香草 一 根 、 鱼酱 两 汤匙 、 辣椒 六 粒 , 纯净 水 四 量杯 、 香菜 半量杯 和 草菇 十 个 。".replace(" ", "")]
src, tgt = tokenizer(sents)
alignment(src, tgt, verbose=True)
\ No newline at end of file
from typing import List, Tuple
from modules.alignment import read_cilin, read_confusion, Alignment
from modules.merger import Merger
from modules.classifier import Classifier
class Annotator:
def __init__(self,
align: Alignment,
merger: Merger,
classifier: Classifier,
granularity: str = "word",
strategy: str = "first"):
self.align = align
self.merger = merger
self.classifier = classifier
self.granularity = granularity
self.strategy = strategy
@classmethod
def create_default(cls, granularity: str = "word", strategy: str = "first"):
"""
Default parameters used in the paper
"""
semantic_dict, semantic_class = read_cilin()
confusion_dict = read_confusion()
align = Alignment(semantic_dict, confusion_dict, granularity)
merger = Merger(granularity)
classifier = Classifier(granularity)
return cls(align, merger, classifier, granularity, strategy)
def __call__(self,
src: List[Tuple],
tgt: List[Tuple],
annotator_id: int = 0,
verbose: bool = False):
"""
Align sentences and annotate them with error type information
"""
src_tokens = [x[0] for x in src]
tgt_tokens = [x[0] for x in tgt]
src_str = "".join(src_tokens)
tgt_str = "".join(tgt_tokens)
# convert to text form
annotations_out = ["S " + " ".join(src_tokens) + "\n"]
if tgt_str == "没有错误" or src_str == tgt_str: # Error Free Case
annotations_out.append(f"T{annotator_id} 没有错误\n")
cors = [tgt_str]
op, toks, inds = "noop", "-NONE-", (-1, -1)
a_str = f"A {inds[0]} {inds[1]}|||{op}|||{toks}|||REQUIRED|||-NONE-|||{annotator_id}\n"
annotations_out.append(a_str)
elif tgt_str == "无法标注": # Not Annotatable Case
annotations_out.append(f"T{annotator_id} 无法标注\n")
cors = [tgt_str]
op, toks, inds = "NA", "-NONE-", (-1, -1)
a_str = f"A {inds[0]} {inds[1]}|||{op}|||{toks}|||REQUIRED|||-NONE-|||{annotator_id}\n"
annotations_out.append(a_str)
else: # Other
align_objs = self.align(src, tgt)
edit_objs = []
align_idx = 0
if self.strategy == "first":
align_objs = align_objs[:1]
for align_obj in align_objs:
edits = self.merger(align_obj, src, tgt, verbose)
if edits not in edit_objs:
edit_objs.append(edits)
annotations_out.append(f"T{annotator_id}-A{align_idx} " + " ".join(tgt_tokens) + "\n")
align_idx += 1
cors = self.classifier(src, tgt, edits, verbose)
# annotations_out = []
for cor in cors:
op, toks, inds = cor.op, cor.toks, cor.inds
a_str = f"A {inds[0]} {inds[1]}|||{op}|||{toks}|||REQUIRED|||-NONE-|||{annotator_id}\n"
annotations_out.append(a_str)
annotations_out.append("\n")
return annotations_out, cors
from typing import List, Tuple
from modules.alignment import read_cilin, read_confusion, Alignment
from modules.merger import Merger
from modules.classifier import Classifier
class Annotator:
def __init__(self,
align: Alignment,
merger: Merger,
classifier: Classifier,
granularity: str = "word",
strategy: str = "first"):
self.align = align
self.merger = merger
self.classifier = classifier
self.granularity = granularity
self.strategy = strategy
@classmethod
def create_default(cls, granularity: str = "word", strategy: str = "first"):
"""
Default parameters used in the paper
"""
semantic_dict, semantic_class = read_cilin()
confusion_dict = read_confusion()
align = Alignment(semantic_dict, confusion_dict, granularity)
merger = Merger(granularity)
classifier = Classifier(granularity)
return cls(align, merger, classifier, granularity, strategy)
def __call__(self,
src: List[Tuple],
tgt: List[Tuple],
annotator_id: int = 0,
verbose: bool = False):
"""
Align sentences and annotate them with error type information
"""
src_tokens = [x[0] for x in src]
tgt_tokens = [x[0] for x in tgt]
src_str = "".join(src_tokens)
tgt_str = "".join(tgt_tokens)
# convert to text form
annotations_out = ["S " + " ".join(src_tokens) + "\n"]
if tgt_str == "没有错误" or src_str == tgt_str: # Error Free Case
annotations_out.append(f"T{annotator_id} 没有错误\n")
cors = [tgt_str]
op, toks, inds = "noop", "-NONE-", (-1, -1)
a_str = f"A {inds[0]} {inds[1]}|||{op}|||{toks}|||REQUIRED|||-NONE-|||{annotator_id}\n"
annotations_out.append(a_str)
elif tgt_str == "无法标注": # Not Annotatable Case
annotations_out.append(f"T{annotator_id} 无法标注\n")
cors = [tgt_str]
op, toks, inds = "NA", "-NONE-", (-1, -1)
a_str = f"A {inds[0]} {inds[1]}|||{op}|||{toks}|||REQUIRED|||-NONE-|||{annotator_id}\n"
annotations_out.append(a_str)
else: # Other
align_objs = self.align(src, tgt)
edit_objs = []
align_idx = 0
if self.strategy == "first":
align_objs = align_objs[:1]
for align_obj in align_objs:
edits = self.merger(align_obj, src, tgt, verbose)
if edits not in edit_objs:
edit_objs.append(edits)
annotations_out.append(f"T{annotator_id}-A{align_idx} " + " ".join(tgt_tokens) + "\n")
align_idx += 1
cors = self.classifier(src, tgt, edits, verbose)
# annotations_out = []
for cor in cors:
op, toks, inds = cor.op, cor.toks, cor.inds
a_str = f"A {inds[0]} {inds[1]}|||{op}|||{toks}|||REQUIRED|||-NONE-|||{annotator_id}\n"
annotations_out.append(a_str)
annotations_out.append("\n")
return annotations_out, cors
from char_smi import CharFuncs
from collections import namedtuple
from pypinyin import pinyin, Style
import os
Correction = namedtuple(
"Correction",
[
"op",
"toks",
"inds",
],
)
file_path = os.path.dirname(os.path.abspath(__file__))
char_smi = CharFuncs(os.path.join(file_path.replace("modules", ""), 'data/char_meta.txt'))
def check_spell_error(src_span: str,
tgt_span: str,
threshold: float = 0.8) -> bool:
if len(src_span) != len(tgt_span):
return False
src_chars = [ch for ch in src_span]
tgt_chars = [ch for ch in tgt_span]
if sorted(src_chars) == sorted(tgt_chars): # 词内部字符异位
return True
for src_char, tgt_char in zip(src_chars, tgt_chars):
if src_char != tgt_char:
if src_char not in char_smi.data or tgt_char not in char_smi.data:
return False
v_sim = char_smi.shape_similarity(src_char, tgt_char)
p_sim = char_smi.pronunciation_similarity(src_char, tgt_char)
if v_sim + p_sim < threshold and not (
set(pinyin(src_char, style=Style.NORMAL, heteronym=True)[0]) & set(pinyin(tgt_char, style=Style.NORMAL, heteronym=True)[0])):
return False
return True
class Classifier:
"""
错误类型分类器
"""
def __init__(self,
granularity: str = "word"):
self.granularity = granularity
@staticmethod
def get_pos_type(pos):
if pos in {"n", "nd"}:
return "NOUN"
if pos in {"nh", "ni", "nl", "ns", "nt", "nz"}:
return "NOUN-NE"
if pos in {"v"}:
return "VERB"
if pos in {"a", "b"}:
return "ADJ"
if pos in {"c"}:
return "CONJ"
if pos in {"r"}:
return "PRON"
if pos in {"d"}:
return "ADV"
if pos in {"u"}:
return "AUX"
# if pos in {"k"}: # TODO 后缀词比例太少,暂且分入其它
# return "SUFFIX"
if pos in {"m"}:
return "NUM"
if pos in {"p"}:
return "PREP"
if pos in {"q"}:
return "QUAN"
if pos in {"wp"}:
return "PUNCT"
return "OTHER"
def __call__(self,
src,
tgt,
edits,
verbose: bool = False):
"""
为编辑操作划分错误类型
:param src: 错误句子信息
:param tgt: 正确句子信息
:param edits: 编辑操作
:param verbose: 是否打印信息
:return: 划分完错误类型后的编辑操作
"""
results = []
src_tokens = [x[0] for x in src]
tgt_tokens = [x[0] for x in tgt]
for edit in edits:
error_type = edit[0]
src_span = " ".join(src_tokens[edit[1]: edit[2]])
tgt_span = " ".join(tgt_tokens[edit[3]: edit[4]])
# print(tgt_span)
cor = None
if error_type[0] == "T":
cor = Correction("W", tgt_span, (edit[1], edit[2]))
elif error_type[0] == "D":
if self.granularity == "word": # 词级别可以细分错误类型
if edit[2] - edit[1] > 1: # 词组冗余暂时分为OTHER
cor = Correction("R:OTHER", "-NONE-", (edit[1], edit[2]))
else:
pos = self.get_pos_type(src[edit[1]][1])
pos = "NOUN" if pos == "NOUN-NE" else pos
pos = "MC" if tgt_span == "[缺失成分]" else pos
cor = Correction("R:{:s}".format(pos), "-NONE-", (edit[1], edit[2]))
else: # 字级别可以只需要根据操作划分类型即可
cor = Correction("R", "-NONE-", (edit[1], edit[2]))
elif error_type[0] == "I":
if self.granularity == "word": # 词级别可以细分错误类型
if edit[4] - edit[3] > 1: # 词组丢失暂时分为OTHER
cor = Correction("M:OTHER", tgt_span, (edit[1], edit[2]))
else:
pos = self.get_pos_type(tgt[edit[3]][1])
pos = "NOUN" if pos == "NOUN-NE" else pos
pos = "MC" if tgt_span == "[缺失成分]" else pos
cor = Correction("M:{:s}".format(pos), tgt_span, (edit[1], edit[2]))
else: # 字级别可以只需要根据操作划分类型即可
cor = Correction("M", tgt_span, (edit[1], edit[2]))
elif error_type[0] == "S":
if self.granularity == "word": # 词级别可以细分错误类型
if check_spell_error(src_span.replace(" ", ""), tgt_span.replace(" ", "")):
cor = Correction("S:SPELL", tgt_span, (edit[1], edit[2]))
# Todo 暂且不单独区分命名实体拼写错误
# if edit[4] - edit[3] > 1:
# cor = Correction("S:SPELL:COMMON", tgt_span, (edit[1], edit[2]))
# else:
# pos = self.get_pos_type(tgt[edit[3]][1])
# if pos == "NOUN-NE": # 命名实体拼写有误
# cor = Correction("S:SPELL:NE", tgt_span, (edit[1], edit[2]))
# else: # 普通词语拼写有误
# cor = Correction("S:SPELL:COMMON", tgt_span, (edit[1], edit[2]))
else:
if edit[4] - edit[3] > 1: # 词组被替换暂时分为OTHER
cor = Correction("S:OTHER", tgt_span, (edit[1], edit[2]))
else:
pos = self.get_pos_type(tgt[edit[3]][1])
pos = "NOUN" if pos == "NOUN-NE" else pos
pos = "MC" if tgt_span == "[缺失成分]" else pos
cor = Correction("S:{:s}".format(pos), tgt_span, (edit[1], edit[2]))
else: # 字级别可以只需要根据操作划分类型即可
cor = Correction("S", tgt_span, (edit[1], edit[2]))
results.append(cor)
if verbose:
print("========== Corrections ==========")
for cor in results:
print("Type: {:s}, Position: {:d} -> {:d}, Target: {:s}".format(cor.op, cor.inds[0], cor.inds[1], cor.toks))
return results
# print(pinyin("朝", style=Style.NORMAL))
from char_smi import CharFuncs
from collections import namedtuple
from pypinyin import pinyin, Style
import os
Correction = namedtuple(
"Correction",
[
"op",
"toks",
"inds",
],
)
file_path = os.path.dirname(os.path.abspath(__file__))
char_smi = CharFuncs(os.path.join(file_path.replace("modules", ""), 'data/char_meta.txt'))
def check_spell_error(src_span: str,
tgt_span: str,
threshold: float = 0.8) -> bool:
if len(src_span) != len(tgt_span):
return False
src_chars = [ch for ch in src_span]
tgt_chars = [ch for ch in tgt_span]
if sorted(src_chars) == sorted(tgt_chars): # 词内部字符异位
return True
for src_char, tgt_char in zip(src_chars, tgt_chars):
if src_char != tgt_char:
if src_char not in char_smi.data or tgt_char not in char_smi.data:
return False
v_sim = char_smi.shape_similarity(src_char, tgt_char)
p_sim = char_smi.pronunciation_similarity(src_char, tgt_char)
if v_sim + p_sim < threshold and not (
set(pinyin(src_char, style=Style.NORMAL, heteronym=True)[0]) & set(pinyin(tgt_char, style=Style.NORMAL, heteronym=True)[0])):
return False
return True
class Classifier:
"""
错误类型分类器
"""
def __init__(self,
granularity: str = "word"):
self.granularity = granularity
@staticmethod
def get_pos_type(pos):
if pos in {"n", "nd"}:
return "NOUN"
if pos in {"nh", "ni", "nl", "ns", "nt", "nz"}:
return "NOUN-NE"
if pos in {"v"}:
return "VERB"
if pos in {"a", "b"}:
return "ADJ"
if pos in {"c"}:
return "CONJ"
if pos in {"r"}:
return "PRON"
if pos in {"d"}:
return "ADV"
if pos in {"u"}:
return "AUX"
# if pos in {"k"}: # TODO 后缀词比例太少,暂且分入其它
# return "SUFFIX"
if pos in {"m"}:
return "NUM"
if pos in {"p"}:
return "PREP"
if pos in {"q"}:
return "QUAN"
if pos in {"wp"}:
return "PUNCT"
return "OTHER"
def __call__(self,
src,
tgt,
edits,
verbose: bool = False):
"""
为编辑操作划分错误类型
:param src: 错误句子信息
:param tgt: 正确句子信息
:param edits: 编辑操作
:param verbose: 是否打印信息
:return: 划分完错误类型后的编辑操作
"""
results = []
src_tokens = [x[0] for x in src]
tgt_tokens = [x[0] for x in tgt]
for edit in edits:
error_type = edit[0]
src_span = " ".join(src_tokens[edit[1]: edit[2]])
tgt_span = " ".join(tgt_tokens[edit[3]: edit[4]])
# print(tgt_span)
cor = None
if error_type[0] == "T":
cor = Correction("W", tgt_span, (edit[1], edit[2]))
elif error_type[0] == "D":
if self.granularity == "word": # 词级别可以细分错误类型
if edit[2] - edit[1] > 1: # 词组冗余暂时分为OTHER
cor = Correction("R:OTHER", "-NONE-", (edit[1], edit[2]))
else:
pos = self.get_pos_type(src[edit[1]][1])
pos = "NOUN" if pos == "NOUN-NE" else pos
pos = "MC" if tgt_span == "[缺失成分]" else pos
cor = Correction("R:{:s}".format(pos), "-NONE-", (edit[1], edit[2]))
else: # 字级别可以只需要根据操作划分类型即可
cor = Correction("R", "-NONE-", (edit[1], edit[2]))
elif error_type[0] == "I":
if self.granularity == "word": # 词级别可以细分错误类型
if edit[4] - edit[3] > 1: # 词组丢失暂时分为OTHER
cor = Correction("M:OTHER", tgt_span, (edit[1], edit[2]))
else:
pos = self.get_pos_type(tgt[edit[3]][1])
pos = "NOUN" if pos == "NOUN-NE" else pos
pos = "MC" if tgt_span == "[缺失成分]" else pos
cor = Correction("M:{:s}".format(pos), tgt_span, (edit[1], edit[2]))
else: # 字级别可以只需要根据操作划分类型即可
cor = Correction("M", tgt_span, (edit[1], edit[2]))
elif error_type[0] == "S":
if self.granularity == "word": # 词级别可以细分错误类型
if check_spell_error(src_span.replace(" ", ""), tgt_span.replace(" ", "")):
cor = Correction("S:SPELL", tgt_span, (edit[1], edit[2]))
# Todo 暂且不单独区分命名实体拼写错误
# if edit[4] - edit[3] > 1:
# cor = Correction("S:SPELL:COMMON", tgt_span, (edit[1], edit[2]))
# else:
# pos = self.get_pos_type(tgt[edit[3]][1])
# if pos == "NOUN-NE": # 命名实体拼写有误
# cor = Correction("S:SPELL:NE", tgt_span, (edit[1], edit[2]))
# else: # 普通词语拼写有误
# cor = Correction("S:SPELL:COMMON", tgt_span, (edit[1], edit[2]))
else:
if edit[4] - edit[3] > 1: # 词组被替换暂时分为OTHER
cor = Correction("S:OTHER", tgt_span, (edit[1], edit[2]))
else:
pos = self.get_pos_type(tgt[edit[3]][1])
pos = "NOUN" if pos == "NOUN-NE" else pos
pos = "MC" if tgt_span == "[缺失成分]" else pos
cor = Correction("S:{:s}".format(pos), tgt_span, (edit[1], edit[2]))
else: # 字级别可以只需要根据操作划分类型即可
cor = Correction("S", tgt_span, (edit[1], edit[2]))
results.append(cor)
if verbose:
print("========== Corrections ==========")
for cor in results:
print("Type: {:s}, Position: {:d} -> {:d}, Target: {:s}".format(cor.op, cor.inds[0], cor.inds[1], cor.toks))
return results
# print(pinyin("朝", style=Style.NORMAL))
from itertools import groupby
from string import punctuation
from typing import List
from modules.tokenizer import Tokenizer
from modules.alignment import Alignment, read_cilin, read_confusion
import Levenshtein
class Merger:
"""
合并编辑操作,从Token-Level转换为Span-Level
"""
def __init__(self,
granularity: str = "word",
merge: bool = False):
chinese_punct = "!?。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟–—‘'‛“”„‟…‧."
self.punctuation = punctuation + chinese_punct
self.not_merge_token = [punct for punct in self.punctuation]
self.granularity = granularity
self.merge = merge
@staticmethod
def _merge_edits(seq, tag="X"):
if seq:
return [(tag, seq[0][1], seq[-1][2], seq[0][3], seq[-1][4])]
else:
return seq
@staticmethod
def _check_revolve(span_a, span_b):
span_a = span_a + span_a
return span_b in span_a
def _process_seq(self, seq, src_tokens, tgt_tokens):
if len(seq) <= 1:
return seq
ops = [op[0] for op in seq]
if set(ops) == {"D"} or set(ops) == {"I"}:
return self._merge_edits(seq, set(ops).pop())
if set(ops) == {"D", "I"} or set(ops) == {"I", "D"}:
# do not merge this pattern_from_qua.txt
return seq
if set(ops) == {"S"}:
if self.granularity == "word":
return seq
else:
return self._merge_edits(seq, "S")
if set(ops) == {"M"}:
return self._merge_edits(seq, "M")
return self._merge_edits(seq, "S")
def __call__(self,
align_obj,
src: List,
tgt: List,
verbose: bool = False):
"""
Based on ERRANT's merge, adapted for Chinese
"""
src_tokens = [x[0] for x in src]
tgt_tokens = [x[0] for x in tgt]
edits = []
# Split alignment into groups of M, T and rest. (T has a number after it)
# Todo 一旦插入、删除、替换的对象中含有标点,那么不与其它编辑合并
# Todo 缺失成分标签也不与其它编辑合并
for op, group in groupby(
align_obj,
lambda x: x[0][0] if x[0][0] in {"M", "T"} else False,
):
group = list(group)
# T is always split TODO: Evaluate this
if op == "T":
for seq in group:
edits.append(seq)
# Process D, I and S subsequence
else:
# Turn the processed sequence into edits
processed = self._process_seq(group, src_tokens, tgt_tokens)
for seq in processed:
edits.append(seq)
filtered_edits = []
i = 0
while i < len(edits):
e1 = edits[i][0][0]
if i < len(edits) - 2:
e2 = edits[i + 1][0][0]
e3 = edits[i + 2][0][0]
# Find "S M S" patterns
# Ex:
# S M S
# 冬阴功 对 外国人
# 外国人 对 冬阴功
if e1 == "S" and e2 == "M" and e3 == "S":
w1 = "".join(src_tokens[edits[i][1]: edits[i][2]])
w2 = "".join(tgt_tokens[edits[i][3]: edits[i][4]])
w3 = "".join(src_tokens[edits[i + 2][1]: edits[i + 2][2]])
w4 = "".join(tgt_tokens[edits[i + 2][3]: edits[i + 2][4]])
if min([len(w1), len(w2), len(w3), len(w4)]) == 1:
if w1 == w4 and w2 == w3:
group = [edits[i], edits[i + 1], edits[i + 2]]
processed = self._merge_edits(group, "T" + str(edits[i+2][2] - edits[i][1]))
for seq in processed:
filtered_edits.append(seq)
i += 3
else:
filtered_edits.append(edits[i])
i += 1
else:
if Levenshtein.distance(w1, w4) <= 1 and Levenshtein.distance(w2, w3) <= 1:
group = [edits[i], edits[i + 1], edits[i + 2]]
processed = self._merge_edits(group, "T" + str(edits[i + 2][2] - edits[i][1]))
for seq in processed:
filtered_edits.append(seq)
i += 3
else:
filtered_edits.append(edits[i])
i += 1
# Find "D M I" or "I M D" patterns
# Ex:
# D M I
# 旅游 去 陌生 的 地方
# 去 陌生 的 地方 旅游
elif (e1 == "D" and (e2 == "M" or e2.startswith("T")) and e3 == "I") or (e1 == "I" and (e2 == "M" or e2.startswith("T")) and e3 == "D"):
if e1 == "D":
delete_token = src_tokens[edits[i][1]: edits[i][2]]
insert_token = tgt_tokens[edits[i + 2][3]: edits[i + 2][4]]
else:
delete_token = src_tokens[edits[i + 2][1]: edits[i + 2][2]]
insert_token = tgt_tokens[edits[i][3]: edits[i][4]]
a, b = "".join(delete_token), "".join(insert_token)
if len(a) < len(b):
a, b = b, a
if a not in self.punctuation and b not in self.punctuation and len(a) - len(b) <= 1:
if len(b) == 1:
if a == b:
group = [edits[i], edits[i + 1], edits[i + 2]]
processed = self._merge_edits(group, "T" + str(edits[i+2][2] - edits[i][1]))
for seq in processed:
filtered_edits.append(seq)
i += 3
else:
filtered_edits.append(edits[i])
i += 1
else:
if Levenshtein.distance(a, b) <= 1 or (len(a) == len(b) and self._check_revolve(a, b)):
group = [edits[i], edits[i + 1], edits[i + 2]]
processed = self._merge_edits(group, "T" + str(edits[i + 2][2] - edits[i][1]))
for seq in processed:
filtered_edits.append(seq)
i += 3
else:
filtered_edits.append(edits[i])
i += 1
else:
filtered_edits.append(edits[i])
i += 1
else:
if e1 != "M":
filtered_edits.append(edits[i])
i += 1
else:
if e1 != "M":
filtered_edits.append(edits[i])
i += 1
# In rare cases with word-level tokenization, the following error can occur:
# M D S M
# 有 時 住 上層
# 有 時住 上層
# Which results in S: 時住 --> 時住
# We need to filter this case out
second_filter = []
for edit in filtered_edits: # 避免因为分词错误导致的mismatch现象
span1 = "".join(src_tokens[edit[1] : edit[2]])
span2 = "".join(tgt_tokens[edit[3] : edit[4]])
if span1 != span2:
if edit[0] == "S":
b = True
# In rare cases with word-level tokenization, the following error can occur:
# S I I M
# 负责任 老师
# 负 责任 的 老师
# Which results in S: 负责任 --> 负 责任 的
# We need to convert this edit to I: --> 的
# 首部有重叠
common_str = ""
tmp_new_start_1 = edit[1]
for i in range(edit[1], edit[2]):
if not span2.startswith(common_str + src_tokens[i]):
break
common_str += src_tokens[i]
tmp_new_start_1 = i + 1
new_start_1, new_start_2 = edit[1], edit[3]
if common_str:
tmp_str = ""
for i in range(edit[3], edit[4]):
tmp_str += tgt_tokens[i]
if tmp_str == common_str:
new_start_1, new_start_2 = tmp_new_start_1, i + 1
# second_filter.append(("S", new_start_1, edit[2], i + 1, edit[4]))
b = False
break
elif len(tmp_str) > len(common_str):
break
# 尾部有重叠
common_str = ""
new_end_1, new_end_2 = edit[2], edit[4]
tmp_new_end_1 = edit[2]
for i in reversed(range(new_start_1, edit[2])):
if not span2.endswith(src_tokens[i] + common_str):
break
common_str = src_tokens[i] + common_str
tmp_new_end_1 = i
if common_str:
tmp_str = ""
for i in reversed(range(new_start_2, edit[4])):
tmp_str = tgt_tokens[i] + tmp_str
if tmp_str == common_str:
new_end_1, new_end_2 = tmp_new_end_1, i
b = False
break
elif len(tmp_str) > len(common_str):
break
if b:
second_filter.append(edit)
else:
if new_start_1 == new_end_1:
new_edit = ("I", new_start_1, new_end_1, new_start_2, new_end_2)
elif new_start_2 == new_end_2:
new_edit = ("D", new_start_1, new_end_1, new_start_2, new_end_2)
else:
new_edit = ("S", new_start_1, new_end_1, new_start_2, new_end_2)
second_filter.append(new_edit)
else:
second_filter.append(edit)
if verbose:
print("========== Parallels ==========")
print("".join(src_tokens))
print("".join(tgt_tokens))
print("========== Results ==========")
for edit in second_filter:
op = edit[0]
s = " ".join(src_tokens[edit[1]: edit[2]])
t = " ".join(tgt_tokens[edit[3]: edit[4]])
print(f"{op}:\t{s}\t-->\t{t}")
print("========== Infos ==========")
print(str(src))
print(str(tgt))
return second_filter
if __name__ == "__main__":
tokenizer = Tokenizer("char")
semantic_dict, semantic_class = read_cilin()
confusion_dict = read_confusion()
alignment = Alignment(semantic_dict, confusion_dict)
sents = [
"所 以 印 度 对 全 世 界 人 没 有 说 服 不 要 吃 牛 肉 。".replace(
" ", ""),
"所 以 印 度 没 有 说 服 全 世 界 人 不 要 吃 牛 肉 。".replace(
" ", "")]
src, tgt = tokenizer(sents)
align_obj = alignment(src, tgt)
m = Merger()
from itertools import groupby
from string import punctuation
from typing import List
from modules.tokenizer import Tokenizer
from modules.alignment import Alignment, read_cilin, read_confusion
import Levenshtein
class Merger:
"""
合并编辑操作,从Token-Level转换为Span-Level
"""
def __init__(self,
granularity: str = "word",
merge: bool = False):
chinese_punct = "!?。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟–—‘'‛“”„‟…‧."
self.punctuation = punctuation + chinese_punct
self.not_merge_token = [punct for punct in self.punctuation]
self.granularity = granularity
self.merge = merge
@staticmethod
def _merge_edits(seq, tag="X"):
if seq:
return [(tag, seq[0][1], seq[-1][2], seq[0][3], seq[-1][4])]
else:
return seq
@staticmethod
def _check_revolve(span_a, span_b):
span_a = span_a + span_a
return span_b in span_a
def _process_seq(self, seq, src_tokens, tgt_tokens):
if len(seq) <= 1:
return seq
ops = [op[0] for op in seq]
if set(ops) == {"D"} or set(ops) == {"I"}:
return self._merge_edits(seq, set(ops).pop())
if set(ops) == {"D", "I"} or set(ops) == {"I", "D"}:
# do not merge this pattern_from_qua.txt
return seq
if set(ops) == {"S"}:
if self.granularity == "word":
return seq
else:
return self._merge_edits(seq, "S")
if set(ops) == {"M"}:
return self._merge_edits(seq, "M")
return self._merge_edits(seq, "S")
def __call__(self,
align_obj,
src: List,
tgt: List,
verbose: bool = False):
"""
Based on ERRANT's merge, adapted for Chinese
"""
src_tokens = [x[0] for x in src]
tgt_tokens = [x[0] for x in tgt]
edits = []
# Split alignment into groups of M, T and rest. (T has a number after it)
# Todo 一旦插入、删除、替换的对象中含有标点,那么不与其它编辑合并
# Todo 缺失成分标签也不与其它编辑合并
for op, group in groupby(
align_obj,
lambda x: x[0][0] if x[0][0] in {"M", "T"} else False,
):
group = list(group)
# T is always split TODO: Evaluate this
if op == "T":
for seq in group:
edits.append(seq)
# Process D, I and S subsequence
else:
# Turn the processed sequence into edits
processed = self._process_seq(group, src_tokens, tgt_tokens)
for seq in processed:
edits.append(seq)
filtered_edits = []
i = 0
while i < len(edits):
e1 = edits[i][0][0]
if i < len(edits) - 2:
e2 = edits[i + 1][0][0]
e3 = edits[i + 2][0][0]
# Find "S M S" patterns
# Ex:
# S M S
# 冬阴功 对 外国人
# 外国人 对 冬阴功
if e1 == "S" and e2 == "M" and e3 == "S":
w1 = "".join(src_tokens[edits[i][1]: edits[i][2]])
w2 = "".join(tgt_tokens[edits[i][3]: edits[i][4]])
w3 = "".join(src_tokens[edits[i + 2][1]: edits[i + 2][2]])
w4 = "".join(tgt_tokens[edits[i + 2][3]: edits[i + 2][4]])
if min([len(w1), len(w2), len(w3), len(w4)]) == 1:
if w1 == w4 and w2 == w3:
group = [edits[i], edits[i + 1], edits[i + 2]]
processed = self._merge_edits(group, "T" + str(edits[i+2][2] - edits[i][1]))
for seq in processed:
filtered_edits.append(seq)
i += 3
else:
filtered_edits.append(edits[i])
i += 1
else:
if Levenshtein.distance(w1, w4) <= 1 and Levenshtein.distance(w2, w3) <= 1:
group = [edits[i], edits[i + 1], edits[i + 2]]
processed = self._merge_edits(group, "T" + str(edits[i + 2][2] - edits[i][1]))
for seq in processed:
filtered_edits.append(seq)
i += 3
else:
filtered_edits.append(edits[i])
i += 1
# Find "D M I" or "I M D" patterns
# Ex:
# D M I
# 旅游 去 陌生 的 地方
# 去 陌生 的 地方 旅游
elif (e1 == "D" and (e2 == "M" or e2.startswith("T")) and e3 == "I") or (e1 == "I" and (e2 == "M" or e2.startswith("T")) and e3 == "D"):
if e1 == "D":
delete_token = src_tokens[edits[i][1]: edits[i][2]]
insert_token = tgt_tokens[edits[i + 2][3]: edits[i + 2][4]]
else:
delete_token = src_tokens[edits[i + 2][1]: edits[i + 2][2]]
insert_token = tgt_tokens[edits[i][3]: edits[i][4]]
a, b = "".join(delete_token), "".join(insert_token)
if len(a) < len(b):
a, b = b, a
if a not in self.punctuation and b not in self.punctuation and len(a) - len(b) <= 1:
if len(b) == 1:
if a == b:
group = [edits[i], edits[i + 1], edits[i + 2]]
processed = self._merge_edits(group, "T" + str(edits[i+2][2] - edits[i][1]))
for seq in processed:
filtered_edits.append(seq)
i += 3
else:
filtered_edits.append(edits[i])
i += 1
else:
if Levenshtein.distance(a, b) <= 1 or (len(a) == len(b) and self._check_revolve(a, b)):
group = [edits[i], edits[i + 1], edits[i + 2]]
processed = self._merge_edits(group, "T" + str(edits[i + 2][2] - edits[i][1]))
for seq in processed:
filtered_edits.append(seq)
i += 3
else:
filtered_edits.append(edits[i])
i += 1
else:
filtered_edits.append(edits[i])
i += 1
else:
if e1 != "M":
filtered_edits.append(edits[i])
i += 1
else:
if e1 != "M":
filtered_edits.append(edits[i])
i += 1
# In rare cases with word-level tokenization, the following error can occur:
# M D S M
# 有 時 住 上層
# 有 時住 上層
# Which results in S: 時住 --> 時住
# We need to filter this case out
second_filter = []
for edit in filtered_edits: # 避免因为分词错误导致的mismatch现象
span1 = "".join(src_tokens[edit[1] : edit[2]])
span2 = "".join(tgt_tokens[edit[3] : edit[4]])
if span1 != span2:
if edit[0] == "S":
b = True
# In rare cases with word-level tokenization, the following error can occur:
# S I I M
# 负责任 老师
# 负 责任 的 老师
# Which results in S: 负责任 --> 负 责任 的
# We need to convert this edit to I: --> 的
# 首部有重叠
common_str = ""
tmp_new_start_1 = edit[1]
for i in range(edit[1], edit[2]):
if not span2.startswith(common_str + src_tokens[i]):
break
common_str += src_tokens[i]
tmp_new_start_1 = i + 1
new_start_1, new_start_2 = edit[1], edit[3]
if common_str:
tmp_str = ""
for i in range(edit[3], edit[4]):
tmp_str += tgt_tokens[i]
if tmp_str == common_str:
new_start_1, new_start_2 = tmp_new_start_1, i + 1
# second_filter.append(("S", new_start_1, edit[2], i + 1, edit[4]))
b = False
break
elif len(tmp_str) > len(common_str):
break
# 尾部有重叠
common_str = ""
new_end_1, new_end_2 = edit[2], edit[4]
tmp_new_end_1 = edit[2]
for i in reversed(range(new_start_1, edit[2])):
if not span2.endswith(src_tokens[i] + common_str):
break
common_str = src_tokens[i] + common_str
tmp_new_end_1 = i
if common_str:
tmp_str = ""
for i in reversed(range(new_start_2, edit[4])):
tmp_str = tgt_tokens[i] + tmp_str
if tmp_str == common_str:
new_end_1, new_end_2 = tmp_new_end_1, i
b = False
break
elif len(tmp_str) > len(common_str):
break
if b:
second_filter.append(edit)
else:
if new_start_1 == new_end_1:
new_edit = ("I", new_start_1, new_end_1, new_start_2, new_end_2)
elif new_start_2 == new_end_2:
new_edit = ("D", new_start_1, new_end_1, new_start_2, new_end_2)
else:
new_edit = ("S", new_start_1, new_end_1, new_start_2, new_end_2)
second_filter.append(new_edit)
else:
second_filter.append(edit)
if verbose:
print("========== Parallels ==========")
print("".join(src_tokens))
print("".join(tgt_tokens))
print("========== Results ==========")
for edit in second_filter:
op = edit[0]
s = " ".join(src_tokens[edit[1]: edit[2]])
t = " ".join(tgt_tokens[edit[3]: edit[4]])
print(f"{op}:\t{s}\t-->\t{t}")
print("========== Infos ==========")
print(str(src))
print(str(tgt))
return second_filter
if __name__ == "__main__":
tokenizer = Tokenizer("char")
semantic_dict, semantic_class = read_cilin()
confusion_dict = read_confusion()
alignment = Alignment(semantic_dict, confusion_dict)
sents = [
"所 以 印 度 对 全 世 界 人 没 有 说 服 不 要 吃 牛 肉 。".replace(
" ", ""),
"所 以 印 度 没 有 说 服 全 世 界 人 不 要 吃 牛 肉 。".replace(
" ", "")]
src, tgt = tokenizer(sents)
align_obj = alignment(src, tgt)
m = Merger()
m(align_obj, src, tgt, verbose=True)
\ No newline at end of file
from ltp import LTP
from typing import List
from pypinyin import pinyin, Style, lazy_pinyin
import torch
import os
import functools
class Tokenizer:
"""
分词器
"""
def __init__(self,
granularity: str = "word",
device: str = "cpu",
segmented: bool = False,
bpe: bool = False,
) -> None:
"""
构造函数
:param mode: 分词模式,可选级别:字级别(char)、词级别(word)
"""
self.ltp = None
if granularity == "word":
self.ltp = LTP(device=torch.device(device) if torch.cuda.is_available() else torch.device("cpu"))
self.ltp.add_words(words=["[缺失成分]"], max_window=6)
self.segmented = segmented
self.granularity = granularity
if self.granularity == "word":
self.tokenizer = self.split_word
elif self.granularity == "char":
self.tokenizer = functools.partial(self.split_char, bpe=bpe)
else:
raise NotImplementedError
def __repr__(self) -> str:
return "{:s}\nMode:{:s}\n}".format(str(self.__class__.__name__), self.mode)
def __call__(self,
input_strings: List[str]
) -> List:
"""
分词函数
:param input_strings: 需要分词的字符串列表
:return: 分词后的结果列表,由元组组成,元组为(token,pos_tag,pinyin)的形式
"""
if not self.segmented:
input_strings = ["".join(s.split(" ")) for s in input_strings]
results = self.tokenizer(input_strings)
return results
def split_char(self, input_strings: List[str], bpe=False) -> List:
"""
分字函数
:param input_strings: 需要分字的字符串
:return: 分字结果
"""
if bpe:
from . import tokenization
project_dir = os.path.dirname(os.path.dirname(__file__))
tokenizer = tokenization.FullTokenizer(vocab_file=os.path.join(project_dir,"data","chinese_vocab.txt"), do_lower_case=False)
results = []
for input_string in input_strings:
if not self.segmented: # 如果没有被分字,就按照每个字符隔开(不考虑英文标点的特殊处理,也不考虑BPE),否则遵循原分字结果
segment_string = " ".join([char for char in input_string] if not bpe else tokenizer.tokenize(input_string))
else:
segment_string = input_string
# print(segment_string)
segment_string = segment_string.replace("[ 缺 失 成 分 ]", "[缺失成分]").split(" ") # 缺失成分当成一个单独的token
results.append([(char, "unk", pinyin(char, style=Style.NORMAL, heteronym=True)[0]) for char in segment_string])
return results
def split_word(self, input_strings: List[str]) -> List:
"""
分词函数
:param input_strings: 需要分词的字符串
:return: 分词结果
"""
if self.segmented:
seg, hidden = self.ltp.seg([input_string.split(" ") for input_string in input_strings], is_preseged=True)
else:
seg, hidden = self.ltp.seg(input_strings)
pos = self.ltp.pos(hidden)
result = []
for s, p in zip(seg, pos):
pinyin = [lazy_pinyin(word) for word in s]
result.append(list(zip(s, p, pinyin)))
return result
if __name__ == "__main__":
tokenizer = Tokenizer("word")
print(tokenizer(["LAC是个优秀的分词工具", "百度是一家高科技公司"]))
from ltp import LTP
from typing import List
from pypinyin import pinyin, Style, lazy_pinyin
import torch
import os
import functools
class Tokenizer:
"""
分词器
"""
def __init__(self,
granularity: str = "word",
device: str = "cpu",
segmented: bool = False,
bpe: bool = False,
) -> None:
"""
构造函数
:param mode: 分词模式,可选级别:字级别(char)、词级别(word)
"""
self.ltp = None
if granularity == "word":
self.ltp = LTP(device=torch.device(device) if torch.cuda.is_available() else torch.device("cpu"))
self.ltp.add_words(words=["[缺失成分]"], max_window=6)
self.segmented = segmented
self.granularity = granularity
if self.granularity == "word":
self.tokenizer = self.split_word
elif self.granularity == "char":
self.tokenizer = functools.partial(self.split_char, bpe=bpe)
else:
raise NotImplementedError
def __repr__(self) -> str:
return "{:s}\nMode:{:s}\n}".format(str(self.__class__.__name__), self.mode)
def __call__(self,
input_strings: List[str]
) -> List:
"""
分词函数
:param input_strings: 需要分词的字符串列表
:return: 分词后的结果列表,由元组组成,元组为(token,pos_tag,pinyin)的形式
"""
if not self.segmented:
input_strings = ["".join(s.split(" ")) for s in input_strings]
results = self.tokenizer(input_strings)
return results
def split_char(self, input_strings: List[str], bpe=False) -> List:
"""
分字函数
:param input_strings: 需要分字的字符串
:return: 分字结果
"""
if bpe:
from . import tokenization
project_dir = os.path.dirname(os.path.dirname(__file__))
tokenizer = tokenization.FullTokenizer(vocab_file=os.path.join(project_dir,"data","chinese_vocab.txt"), do_lower_case=False)
results = []
for input_string in input_strings:
if not self.segmented: # 如果没有被分字,就按照每个字符隔开(不考虑英文标点的特殊处理,也不考虑BPE),否则遵循原分字结果
segment_string = " ".join([char for char in input_string] if not bpe else tokenizer.tokenize(input_string))
else:
segment_string = input_string
# print(segment_string)
segment_string = segment_string.replace("[ 缺 失 成 分 ]", "[缺失成分]").split(" ") # 缺失成分当成一个单独的token
results.append([(char, "unk", pinyin(char, style=Style.NORMAL, heteronym=True)[0]) for char in segment_string])
return results
def split_word(self, input_strings: List[str]) -> List:
"""
分词函数
:param input_strings: 需要分词的字符串
:return: 分词结果
"""
if self.segmented:
seg, hidden = self.ltp.seg([input_string.split(" ") for input_string in input_strings], is_preseged=True)
else:
seg, hidden = self.ltp.seg(input_strings)
pos = self.ltp.pos(hidden)
result = []
for s, p in zip(seg, pos):
pinyin = [lazy_pinyin(word) for word in s]
result.append(list(zip(s, p, pinyin)))
return result
if __name__ == "__main__":
tokenizer = Tokenizer("word")
print(tokenizer(["LAC是个优秀的分词工具", "百度是一家高科技公司"]))
import os
from modules.annotator import Annotator
from modules.tokenizer import Tokenizer
import argparse
from collections import Counter
from tqdm import tqdm
import torch
from collections import defaultdict
from multiprocessing import Pool
from opencc import OpenCC
import timeout_decorator
os.environ["TOKENIZERS_PARALLELISM"] = "false"
annotator, sentence_to_tokenized = None, None
cc = OpenCC("t2s")
@timeout_decorator.timeout(10)
def annotate_with_time_out(line):
"""
:param line:
:return:
"""
sent_list = line.split("\t")[1:]
source = sent_list[0]
if args.segmented:
source = source.strip()
else:
source = "".join(source.strip().split())
output_str = ""
for idx, target in enumerate(sent_list[1:]):
try:
if args.segmented:
target = target.strip()
else:
target = "".join(target.strip().split())
if not args.no_simplified:
target = cc.convert(target)
source_tokenized, target_tokenized = sentence_to_tokenized[source], sentence_to_tokenized[target]
out, cors = annotator(source_tokenized, target_tokenized, idx)
if idx == 0:
output_str += "".join(out[:-1])
else:
output_str += "".join(out[1:-1])
except Exception:
raise Exception
return output_str
def annotate(line):
"""
:param line:
:return:
"""
sent_list = line.split("\t")[1:]
source = sent_list[0]
if args.segmented:
source = source.strip()
else:
source = "".join(source.strip().split())
output_str = ""
for idx, target in enumerate(sent_list[1:]):
try:
if args.segmented:
target = target.strip()
else:
target = "".join(target.strip().split())
if not args.no_simplified:
target = cc.convert(target)
source_tokenized, target_tokenized = sentence_to_tokenized[source], sentence_to_tokenized[target]
out, cors = annotator(source_tokenized, target_tokenized, idx)
if idx == 0:
output_str += "".join(out[:-1])
else:
output_str += "".join(out[1:-1])
except Exception:
raise Exception
return output_str
def firsttime_process(args):
tokenizer = Tokenizer(args.granularity, args.device, args.segmented, args.bpe)
global annotator, sentence_to_tokenized
annotator = Annotator.create_default(args.granularity, args.multi_cheapest_strategy)
lines = open(args.file, "r", encoding="utf-8").read().strip().split("\n") # format: id src tgt1 tgt2...
# error_types = []
with open(args.output, "w", encoding="utf-8") as f:
count = 0
sentence_set = set()
sentence_to_tokenized = {}
for line in lines:
sent_list = line.split("\t")[1:]
for idx, sent in enumerate(sent_list):
if args.segmented:
# print(sent)
sent = sent.strip()
else:
sent = "".join(sent.split()).strip()
if idx >= 1:
if not args.no_simplified:
sentence_set.add(cc.convert(sent))
else:
sentence_set.add(sent)
else:
sentence_set.add(sent)
batch = []
for sent in tqdm(sentence_set):
count += 1
if sent:
batch.append(sent)
if count % args.batch_size == 0:
results = tokenizer(batch)
for s, r in zip(batch, results):
sentence_to_tokenized[s] = r # Get tokenization map.
batch = []
if batch:
results = tokenizer(batch)
for s, r in zip(batch, results):
sentence_to_tokenized[s] = r # Get tokenization map.
timeout_indices = []
# 单进程模式
for idx, line in enumerate(tqdm(lines)):
try:
ret = annotate_with_time_out(line)
except Exception:
timeout_indices.append(idx)
return timeout_indices
def main(args):
timeout_indices = firsttime_process(args)
tokenizer = Tokenizer(args.granularity, args.device, args.segmented, args.bpe)
global annotator, sentence_to_tokenized
annotator = Annotator.create_default(args.granularity, args.multi_cheapest_strategy)
lines = open(args.file, "r", encoding="utf-8").read().strip().split("\n")
new_lines = []# format: id src tgt1 tgt2...
with open(args.output, "w", encoding="utf-8") as f:
count = 0
sentence_set = set()
sentence_to_tokenized = {}
for line_idx, line in enumerate(lines):
if line_idx in timeout_indices:
# print(f"line before split: {line}")
line_split = line.split("\t")
line_number, sent_list = line_split[0], line_split[1:]
assert len(sent_list) == 2
sent_list[-1] = " 无"
line = line_number + "\t" + "\t".join(sent_list)
# print(f"line time out: {line}")
new_lines.append(line)
else:
new_lines.append(line)
sent_list = line.split("\t")[1:]
for idx, sent in enumerate(sent_list):
if args.segmented:
# print(sent)
sent = sent.strip()
else:
sent = "".join(sent.split()).strip()
if idx >= 1:
if not args.no_simplified:
sentence_set.add(cc.convert(sent))
else:
sentence_set.add(sent)
else:
sentence_set.add(sent)
batch = []
for sent in tqdm(sentence_set):
count += 1
if sent:
batch.append(sent)
if count % args.batch_size == 0:
results = tokenizer(batch)
for s, r in zip(batch, results):
sentence_to_tokenized[s] = r # Get tokenization map.
batch = []
if batch:
results = tokenizer(batch)
for s, r in zip(batch, results):
sentence_to_tokenized[s] = r # Get tokenization map.
# 单进程模式
lines = new_lines
for idx, line in enumerate(tqdm(lines)):
ret = annotate(line)
f.write(ret)
f.write("\n")
# 多进程模式:仅在Linux环境下测试,建议在linux服务器上使用
# with Pool(args.worker_num) as pool:
# for ret in pool.imap(annotate, tqdm(lines), chunksize=8):
# if ret:
# f.write(ret)
# f.write("\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Choose input file to annotate")
parser.add_argument("-f", "--file", type=str, required=True, help="Input parallel file")
parser.add_argument("-o", "--output", type=str, help="Output file", required=True)
parser.add_argument("-b", "--batch_size", type=int, help="The size of batch", default=128)
parser.add_argument("-d", "--device", type=int, help="The ID of GPU", default=0)
parser.add_argument("-w", "--worker_num", type=int, help="The number of workers", default=16)
parser.add_argument("-g", "--granularity", type=str, help="Choose char-level or word-level evaluation", default="char")
parser.add_argument("-m", "--merge", help="Whether merge continuous replacement/deletion/insertion", action="store_true")
parser.add_argument("-s", "--multi_cheapest_strategy", type=str, choices=["first", "all"], default="all")
parser.add_argument("--segmented", help="Whether tokens have been segmented", action="store_true") # 支持提前token化,用空格隔开
parser.add_argument("--no_simplified", help="Whether simplifying chinese", action="store_true") # 将所有corrections转换为简体中文
parser.add_argument("--bpe", help="Whether to use bpe", action="store_true") # 支持 bpe 切分英文单词
args = parser.parse_args()
main(args)
import os
from modules.annotator import Annotator
from modules.tokenizer import Tokenizer
import argparse
from collections import Counter
from tqdm import tqdm
import torch
from collections import defaultdict
from multiprocessing import Pool
from opencc import OpenCC
import timeout_decorator
os.environ["TOKENIZERS_PARALLELISM"] = "false"
annotator, sentence_to_tokenized = None, None
cc = OpenCC("t2s")
@timeout_decorator.timeout(10)
def annotate_with_time_out(line):
"""
:param line:
:return:
"""
sent_list = line.split("\t")[1:]
source = sent_list[0]
if args.segmented:
source = source.strip()
else:
source = "".join(source.strip().split())
output_str = ""
for idx, target in enumerate(sent_list[1:]):
try:
if args.segmented:
target = target.strip()
else:
target = "".join(target.strip().split())
if not args.no_simplified:
target = cc.convert(target)
source_tokenized, target_tokenized = sentence_to_tokenized[source], sentence_to_tokenized[target]
out, cors = annotator(source_tokenized, target_tokenized, idx)
if idx == 0:
output_str += "".join(out[:-1])
else:
output_str += "".join(out[1:-1])
except Exception:
raise Exception
return output_str
def annotate(line):
"""
:param line:
:return:
"""
sent_list = line.split("\t")[1:]
source = sent_list[0]
if args.segmented:
source = source.strip()
else:
source = "".join(source.strip().split())
output_str = ""
for idx, target in enumerate(sent_list[1:]):
try:
if args.segmented:
target = target.strip()
else:
target = "".join(target.strip().split())
if not args.no_simplified:
target = cc.convert(target)
source_tokenized, target_tokenized = sentence_to_tokenized[source], sentence_to_tokenized[target]
out, cors = annotator(source_tokenized, target_tokenized, idx)
if idx == 0:
output_str += "".join(out[:-1])
else:
output_str += "".join(out[1:-1])
except Exception:
raise Exception
return output_str
def firsttime_process(args):
tokenizer = Tokenizer(args.granularity, args.device, args.segmented, args.bpe)
global annotator, sentence_to_tokenized
annotator = Annotator.create_default(args.granularity, args.multi_cheapest_strategy)
lines = open(args.file, "r", encoding="utf-8").read().strip().split("\n") # format: id src tgt1 tgt2...
# error_types = []
with open(args.output, "w", encoding="utf-8") as f:
count = 0
sentence_set = set()
sentence_to_tokenized = {}
for line in lines:
sent_list = line.split("\t")[1:]
for idx, sent in enumerate(sent_list):
if args.segmented:
# print(sent)
sent = sent.strip()
else:
sent = "".join(sent.split()).strip()
if idx >= 1:
if not args.no_simplified:
sentence_set.add(cc.convert(sent))
else:
sentence_set.add(sent)
else:
sentence_set.add(sent)
batch = []
for sent in tqdm(sentence_set):
count += 1
if sent:
batch.append(sent)
if count % args.batch_size == 0:
results = tokenizer(batch)
for s, r in zip(batch, results):
sentence_to_tokenized[s] = r # Get tokenization map.
batch = []
if batch:
results = tokenizer(batch)
for s, r in zip(batch, results):
sentence_to_tokenized[s] = r # Get tokenization map.
timeout_indices = []
# 单进程模式
for idx, line in enumerate(tqdm(lines)):
try:
ret = annotate_with_time_out(line)
except Exception:
timeout_indices.append(idx)
return timeout_indices
def main(args):
timeout_indices = firsttime_process(args)
tokenizer = Tokenizer(args.granularity, args.device, args.segmented, args.bpe)
global annotator, sentence_to_tokenized
annotator = Annotator.create_default(args.granularity, args.multi_cheapest_strategy)
lines = open(args.file, "r", encoding="utf-8").read().strip().split("\n")
new_lines = []# format: id src tgt1 tgt2...
with open(args.output, "w", encoding="utf-8") as f:
count = 0
sentence_set = set()
sentence_to_tokenized = {}
for line_idx, line in enumerate(lines):
if line_idx in timeout_indices:
# print(f"line before split: {line}")
line_split = line.split("\t")
line_number, sent_list = line_split[0], line_split[1:]
assert len(sent_list) == 2
sent_list[-1] = " 无"
line = line_number + "\t" + "\t".join(sent_list)
# print(f"line time out: {line}")
new_lines.append(line)
else:
new_lines.append(line)
sent_list = line.split("\t")[1:]
for idx, sent in enumerate(sent_list):
if args.segmented:
# print(sent)
sent = sent.strip()
else:
sent = "".join(sent.split()).strip()
if idx >= 1:
if not args.no_simplified:
sentence_set.add(cc.convert(sent))
else:
sentence_set.add(sent)
else:
sentence_set.add(sent)
batch = []
for sent in tqdm(sentence_set):
count += 1
if sent:
batch.append(sent)
if count % args.batch_size == 0:
results = tokenizer(batch)
for s, r in zip(batch, results):
sentence_to_tokenized[s] = r # Get tokenization map.
batch = []
if batch:
results = tokenizer(batch)
for s, r in zip(batch, results):
sentence_to_tokenized[s] = r # Get tokenization map.
# 单进程模式
lines = new_lines
for idx, line in enumerate(tqdm(lines)):
ret = annotate(line)
f.write(ret)
f.write("\n")
# 多进程模式:仅在Linux环境下测试,建议在linux服务器上使用
# with Pool(args.worker_num) as pool:
# for ret in pool.imap(annotate, tqdm(lines), chunksize=8):
# if ret:
# f.write(ret)
# f.write("\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Choose input file to annotate")
parser.add_argument("-f", "--file", type=str, required=True, help="Input parallel file")
parser.add_argument("-o", "--output", type=str, help="Output file", required=True)
parser.add_argument("-b", "--batch_size", type=int, help="The size of batch", default=128)
parser.add_argument("-d", "--device", type=int, help="The ID of GPU", default=0)
parser.add_argument("-w", "--worker_num", type=int, help="The number of workers", default=16)
parser.add_argument("-g", "--granularity", type=str, help="Choose char-level or word-level evaluation", default="char")
parser.add_argument("-m", "--merge", help="Whether merge continuous replacement/deletion/insertion", action="store_true")
parser.add_argument("-s", "--multi_cheapest_strategy", type=str, choices=["first", "all"], default="all")
parser.add_argument("--segmented", help="Whether tokens have been segmented", action="store_true") # 支持提前token化,用空格隔开
parser.add_argument("--no_simplified", help="Whether simplifying chinese", action="store_true") # 将所有corrections转换为简体中文
parser.add_argument("--bpe", help="Whether to use bpe", action="store_true") # 支持 bpe 切分英文单词
args = parser.parse_args()
main(args)
import datetime
import os
import os.path as osp
import random
import re
import subprocess
import sys
import time
from functools import partial
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Optional, Tuple
import mmengine
from mmengine.config import ConfigDict
......@@ -43,6 +46,11 @@ class DLCRunner(BaseRunner):
self.max_num_workers = max_num_workers
self.retry = retry
logger = get_logger()
logger.warning(
'To ensure the integrity of the log results, the log displayed '
f'by {self.__class__.__name__} has a 10-second delay.')
def launch(self, tasks: List[Dict[str, Any]]) -> List[Tuple[str, int]]:
"""Launch multiple tasks.
......@@ -63,18 +71,23 @@ class DLCRunner(BaseRunner):
status = [self._launch(task, random_sleep=False) for task in tasks]
return status
def _launch(self, cfg: ConfigDict, random_sleep: bool = True):
def _launch(self, cfg: ConfigDict, random_sleep: Optional[bool] = None):
"""Launch a single task.
Args:
cfg (ConfigDict): Task config.
random_sleep (bool): Whether to sleep for a random time before
running the command. This avoids cluster error when launching
multiple tasks at the same time. Default: True.
running the command. When Aliyun has many tasks to schedule,
its stability decreases. Therefore, when we need to submit a
large number of tasks at once, we adopt the "random_sleep"
strategy. Tasks that would have been submitted all at once are
now evenly spread out over a 10-second period. Default: None.
Returns:
tuple[str, int]: Task name and exit code.
"""
if random_sleep is None:
random_sleep = (self.max_num_workers > 32)
task = TASKS.build(dict(cfg=cfg, type=self.task_cfg['type']))
num_gpus = task.num_gpus
......@@ -116,7 +129,7 @@ class DLCRunner(BaseRunner):
# Run command with retry
if self.debug:
stdout = None
stdout = sys.stdout
else:
out_path = task.get_log_path(file_extension='out')
mmengine.mkdir_or_exist(osp.split(out_path)[0])
......@@ -124,30 +137,92 @@ class DLCRunner(BaseRunner):
if random_sleep:
time.sleep(random.randint(0, 10))
result = subprocess.run(cmd,
shell=True,
text=True,
stdout=stdout,
stderr=stdout)
def _run_within_retry():
try:
process = subprocess.Popen(cmd,
shell=True,
text=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
job_id = None
job_allocated = False
job_finished = False
last_end_time = datetime.datetime.now().strftime(
'%Y-%m-%dT%H:%M:%SZ')
while True:
if not job_allocated:
line = process.stdout.readline()
if not line:
break
match = re.search(r'(dlc[0-9a-z]+)', line)
if match and job_id is None:
job_id = match.group(1)
stdout.write(line)
match = re.search(r'Job .* is \[Running\]', line)
if match:
job_allocated = True
else:
try:
process.wait(10)
except subprocess.TimeoutExpired:
pass
else:
job_finished = True
if job_finished:
this_end_time = datetime.datetime.now(
).strftime('%Y-%m-%dT%H:%M:%SZ')
else:
this_end_time = (
datetime.datetime.now() -
datetime.timedelta(seconds=10)
).strftime('%Y-%m-%dT%H:%M:%SZ')
logs_cmd = (
'dlc logs'
f' {job_id} {job_id}-worker-0'
f' --start_time {last_end_time}'
f' --end_time {this_end_time}'
f" -c {self.aliyun_cfg['dlc_config_path']}")
log_process = subprocess.Popen(
logs_cmd,
shell=True,
text=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
log_output, log_err = log_process.communicate()
log_output = '\n'.join(log_output.split('\n')[2:])
stdout.write(log_output)
last_end_time = this_end_time
stdout.flush()
if job_finished:
break
process.wait()
return process.returncode
finally:
if job_id is not None:
cancel_cmd = (
'dlc stop job'
f' {job_id}'
f" -c {self.aliyun_cfg['dlc_config_path']}"
' -f')
subprocess.run(cancel_cmd,
shell=True,
text=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
return_code = _run_within_retry()
retry = self.retry
output_paths = task.get_output_paths()
while self._job_failed(result.returncode,
output_paths) and retry > 0:
while self._job_failed(return_code, output_paths) and retry > 0:
retry -= 1
if random_sleep:
time.sleep(random.randint(0, 10))
# Re-generate command to refresh ports.
cmd = get_cmd()
result = subprocess.run(cmd,
shell=True,
text=True,
stdout=stdout,
stderr=stdout)
return_code = _run_within_retry()
finally:
# Clean up
os.remove(param_file)
return task_name, result.returncode
return task_name, return_code
def _job_failed(self, return_code: int, output_paths: List[str]) -> bool:
return return_code != 0 or not all(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment