OpenDAS / opencompass · Commits · 4dd9a3fc

Unverified commit 4dd9a3fc, authored Oct 18, 2023 by Leymore; committed by GitHub on Oct 18, 2023.

[Sync] sync with internal codes 20231019 (#488)

Parent: 2737249f
Changes: 29; this view shows 9 changed files with 1723 additions and 1648 deletions.
Changed files shown:

opencompass/datasets/lawbench/utils/compare_m2_for_evaluation.py  +433 -433
opencompass/datasets/lawbench/utils/function_utils.py  +49 -49
opencompass/datasets/lawbench/utils/modules/alignment.py  +333 -333
opencompass/datasets/lawbench/utils/modules/annotator.py  +76 -76
opencompass/datasets/lawbench/utils/modules/classifier.py  +151 -151
opencompass/datasets/lawbench/utils/modules/merger.py  +272 -272
opencompass/datasets/lawbench/utils/modules/tokenizer.py  +92 -92
opencompass/datasets/lawbench/utils/parallel_to_m2.py  +221 -221
opencompass/runners/dlc.py  +96 -21
opencompass/datasets/lawbench/utils/compare_m2_for_evaluation.py (view file @ 4dd9a3fc)

import argparse
from collections import Counter


def main():
    # Parse command line args
    args = parse_args()
    # Open hypothesis and reference m2 files and split into chunks
    hyp_m2 = open(args.hyp).read().strip().split("\n\n")[args.start:args.end] if args.start is not None and args.end is not None else open(args.hyp).read().strip().split("\n\n")
    ref_m2 = open(args.ref).read().strip().split("\n\n")[args.start:args.end] if args.start is not None and args.end is not None else open(args.ref).read().strip().split("\n\n")
    # Make sure they have the same number of sentences
    assert len(hyp_m2) == len(ref_m2), print(len(hyp_m2), len(ref_m2))

    # Store global corpus level best counts here
    best_dict = Counter({"tp": 0, "fp": 0, "fn": 0})
    best_cats = {}
    # Process each sentence
    sents = zip(hyp_m2, ref_m2)
    for sent_id, sent in enumerate(sents):
        # Simplify the edits into lists of lists
        # if "A1" in sent[0] or "A1" in sent[1] or sent_id in sent_id_cons:
        #     sent_id_cons.append(sent_id)
        src = sent[0].split("\n")[0]
        hyp_edits = simplify_edits(sent[0], args.max_answer_num)
        ref_edits = simplify_edits(sent[1], args.max_answer_num)
        # Process the edits for detection/correction based on args
        hyp_dict = process_edits(hyp_edits, args)
        ref_dict = process_edits(ref_edits, args)
        if args.reference_num is None or len(ref_dict.keys()) == args.reference_num:
            # Evaluate edits and get best TP, FP, FN hyp+ref combo.
            count_dict, cat_dict = evaluate_edits(src,
                hyp_dict, ref_dict, best_dict, sent_id, args)
            # Merge these dicts with best_dict and best_cats
            best_dict += Counter(count_dict)
            best_cats = merge_dict(best_cats, cat_dict)
    # Print results
    print_results(best_dict, best_cats, args)


# Parse command line args
def parse_args():
    parser = argparse.ArgumentParser(
        description="Calculate F-scores for error detection and/or correction.\n"
                    "Flags let you evaluate at different levels of granularity.",
        formatter_class=argparse.RawTextHelpFormatter,
        usage="%(prog)s [options] -hyp HYP -ref REF")
    parser.add_argument(
        "-hyp",
        help="A hypothesis M2 file.",
        required=True)
    parser.add_argument(
        "-ref",
        help="A reference M2 file.",
        required=True)
    parser.add_argument(
        "--start",
        type=int,
        default=None
    )
    parser.add_argument(
        "--end",
        type=int,
        default=None
    )
    parser.add_argument(
        "--max_answer_num",
        type=int,
        default=None
    )
    parser.add_argument(
        "--reference_num",
        type=int,
        default=None
    )
    parser.add_argument(
        "-b",
        "--beta",
        help="Value of beta in F-score. (default: 0.5)",
        default=0.5,
        type=float)
    parser.add_argument(
        "-v",
        "--verbose",
        help="Print verbose output.",
        action="store_true")
    eval_type = parser.add_mutually_exclusive_group()
    eval_type.add_argument(
        "-dt",
        help="Evaluate Detection in terms of Tokens.",
        action="store_true")
    eval_type.add_argument(
        "-ds",
        help="Evaluate Detection in terms of Spans.",
        action="store_true")
    eval_type.add_argument(
        "-cs",
        help="Evaluate Correction in terms of Spans. (default)",
        action="store_true")
    eval_type.add_argument(
        "-cse",
        help="Evaluate Correction in terms of Spans and Error types.",
        action="store_true")
    parser.add_argument(
        "-single",
        help="Only evaluate single token edits; i.e. 0:1, 1:0 or 1:1",
        action="store_true")
    parser.add_argument(
        "-multi",
        help="Only evaluate multi token edits; i.e. 2+:n or n:2+",
        action="store_true")
    parser.add_argument(
        "-multi_hyp_avg",
        help="When get multiple hypotheses for a sentence, calculate their average F-scores for this sentence.",
        action="store_true")  # For IAA calculation
    parser.add_argument(
        "-multi_hyp_max",
        help="When get multiple hypotheses for a sentence, calculate their F-scores and select the max one for this sentence.",
        action="store_true")  # For multiple hypotheses system evaluation
    parser.add_argument(
        "-filt",
        help="Do not evaluate the specified error types.",
        nargs="+",
        default=[])
    parser.add_argument(
        "-cat",
        help="Show error category scores.\n"
             "1: Only show operation tier scores; e.g. R.\n"
             "2: Only show main tier scores; e.g. NOUN.\n"
             "3: Show all category scores; e.g. R:NOUN.",
        choices=[1, 2, 3],
        type=int)
    args = parser.parse_args()
    return args


# Input: An m2 format sentence with edits.
# Output: A list of lists. Each edit: [start, end, cat, cor, coder]
def simplify_edits(sent, max_answer_num):
    out_edits = []
    # Get the edit lines from an m2 block.
    edits = sent.split("\n")
    # Loop through the edits
    for edit in edits:
        # Preprocessing
        if edit.startswith("A "):
            edit = edit[2:].split("|||")  # Ignore "A " then split.
            span = edit[0].split()
            start = int(span[0])
            end = int(span[1])
            cat = edit[1]
            cor = edit[2].replace(" ", "")
            coder = int(edit[-1])
            out_edit = [start, end, cat, cor, coder]
            out_edits.append(out_edit)
    # return [edit for edit in out_edits if edit[-1] in [0,1]]
    if max_answer_num is None:
        return out_edits
    elif max_answer_num == 1:
        return [edit for edit in out_edits if edit[-1] == 0]
    elif max_answer_num == 2:
        return [edit for edit in out_edits if edit[-1] in [0, 1]]
    elif max_answer_num == 3:
        return [edit for edit in out_edits if edit[-1] in [0, 1, 2]]


# Input 1: A list of edits. Each edit: [start, end, cat, cor, coder]
# Input 2: Command line args
# Output: A dict; key is coder, value is edit dict.
def process_edits(edits, args):
    coder_dict = {}
    # Add an explicit noop edit if there are no edits.
    if not edits:
        edits = [[-1, -1, "noop", "-NONE-", 0]]
    # Loop through the edits
    for edit in edits:
        # Name the edit elements for clarity
        start = edit[0]
        end = edit[1]
        cat = edit[2]
        cor = edit[3]
        coder = edit[4]
        # Add the coder to the coder_dict if necessary
        if coder not in coder_dict:
            coder_dict[coder] = {}
        # Optionally apply filters based on args
        # 1. UNK type edits are only useful for detection, not correction.
        if not args.dt and not args.ds and cat == "UNK":
            continue
        # 2. Only evaluate single token edits; i.e. 0:1, 1:0 or 1:1
        if args.single and (end - start >= 2 or len(cor.split()) >= 2):
            continue
        # 3. Only evaluate multi token edits; i.e. 2+:n or n:2+
        if args.multi and end - start < 2 and len(cor.split()) < 2:
            continue
        # 4. If there is a filter, ignore the specified error types
        if args.filt and cat in args.filt:
            continue

        # Token Based Detection
        if args.dt:
            # Preserve noop edits.
            if start == -1:
                if (start, start) in coder_dict[coder].keys():
                    coder_dict[coder][(start, start)].append(cat)
                else:
                    coder_dict[coder][(start, start)] = [cat]
            # Insertions defined as affecting the token on the right
            elif start == end and start >= 0:
                if (start, start + 1) in coder_dict[coder].keys():
                    coder_dict[coder][(start, start + 1)].append(cat)
                else:
                    coder_dict[coder][(start, start + 1)] = [cat]
            # Edit spans are split for each token in the range.
            else:
                for tok_id in range(start, end):
                    if (tok_id, tok_id + 1) in coder_dict[coder].keys():
                        coder_dict[coder][(tok_id, tok_id + 1)].append(cat)
                    else:
                        coder_dict[coder][(tok_id, tok_id + 1)] = [cat]
        # Span Based Detection
        elif args.ds:
            if (start, end) in coder_dict[coder].keys():
                coder_dict[coder][(start, end)].append(cat)
            else:
                coder_dict[coder][(start, end)] = [cat]
        # Span Based Correction
        else:
            # With error type classification
            if args.cse:
                if (start, end, cat, cor) in coder_dict[coder].keys():
                    coder_dict[coder][(start, end, cat, cor)].append(cat)
                else:
                    coder_dict[coder][(start, end, cat, cor)] = [cat]
            # Without error type classification
            else:
                if (start, end, cor) in coder_dict[coder].keys():
                    coder_dict[coder][(start, end, cor)].append(cat)
                else:
                    coder_dict[coder][(start, end, cor)] = [cat]
    return coder_dict


# Input 1: A hyp dict; key is coder_id, value is dict of processed hyp edits.
# Input 2: A ref dict; key is coder_id, value is dict of processed ref edits.
# Input 3: A dictionary of the best corpus level TP, FP and FN counts so far.
# Input 4: Sentence ID (for verbose output only)
# Input 5: Command line args
# Output 1: A dict of the best corpus level TP, FP and FN for the input sentence.
# Output 2: The corresponding error type dict for the above dict.
def evaluate_edits(src, hyp_dict, ref_dict, best, sent_id, args):
    # Store the best sentence level scores and hyp+ref combination IDs
    # best_f is initialised as -1 cause 0 is a valid result.
    best_tp, best_fp, best_fn, best_f, best_hyp, best_ref = 0, 0, 0, -1, 0, 0
    best_cat = {}
    # skip not annotatable sentence
    if len(ref_dict.keys()) == 1:
        ref_id = list(ref_dict.keys())[0]
        if len(ref_dict[ref_id].keys()) == 1:
            cat = list(ref_dict[ref_id].values())[0][0]
            if cat == "NA":
                best_dict = {"tp": best_tp, "fp": best_fp, "fn": best_fn}
                return best_dict, best_cat

    # Compare each hyp and ref combination
    for hyp_id in hyp_dict.keys():
        for ref_id in ref_dict.keys():
            # Get the local counts for the current combination.
            tp, fp, fn, cat_dict = compareEdits(hyp_dict[hyp_id], ref_dict[ref_id])
            # Compute the local sentence scores (for verbose output only)
            loc_p, loc_r, loc_f = computeFScore(tp, fp, fn, args.beta)
            # Compute the global sentence scores
            p, r, f = computeFScore(
                tp + best["tp"], fp + best["fp"], fn + best["fn"], args.beta)
            # Save the scores if they are better in terms of:
            # 1. Higher F-score
            # 2. Same F-score, higher TP
            # 3. Same F-score and TP, lower FP
            # 4. Same F-score, TP and FP, lower FN
            if (f > best_f) or \
                    (f == best_f and tp > best_tp) or \
                    (f == best_f and tp == best_tp and fp < best_fp) or \
                    (f == best_f and tp == best_tp and fp == best_fp and fn < best_fn):
                best_tp, best_fp, best_fn = tp, fp, fn
                best_f, best_hyp, best_ref = f, hyp_id, ref_id
                best_cat = cat_dict
            # Verbose output
            if args.verbose:
                # Prepare verbose output edits.
                hyp_verb = list(sorted(hyp_dict[hyp_id].keys()))
                ref_verb = list(sorted(ref_dict[ref_id].keys()))
                # Ignore noop edits
                if not hyp_verb or hyp_verb[0][0] == -1: hyp_verb = []
                if not ref_verb or ref_verb[0][0] == -1: ref_verb = []
                # Print verbose info
                print('{:-^40}'.format(""))
                print("SENTENCE " + str(sent_id) + src[1:])
                print('{:-^40}'.format(""))
                print("SENTENCE " + str(sent_id) + " - HYP " + str(hyp_id) + " - REF " + str(ref_id))
                print("HYPOTHESIS EDITS :", hyp_verb)
                print("REFERENCE EDITS :", ref_verb)
                print("Local TP/FP/FN :", str(tp), str(fp), str(fn))
                print("Local P/R/F" + str(args.beta) + " :", str(loc_p), str(loc_r), str(loc_f))
                print("Global TP/FP/FN :", str(tp + best["tp"]), str(fp + best["fp"]), str(fn + best["fn"]))
                print("Global P/R/F" + str(args.beta) + " :", str(p), str(r), str(f))
    # Verbose output: display the best hyp+ref combination
    if args.verbose:
        print('{:-^40}'.format(""))
        print("^^ HYP " + str(best_hyp) + ", REF " + str(best_ref) + " chosen for sentence " + str(sent_id))
    # Save the best TP, FP and FNs as a dict, and return this and the best_cat dict
    best_dict = {"tp": best_tp, "fp": best_fp, "fn": best_fn}
    return best_dict, best_cat


# Input 1: A dictionary of hypothesis edits for a single system.
# Input 2: A dictionary of reference edits for a single annotator.
# Output 1-3: The TP, FP and FN for the hyp vs the given ref annotator.
# Output 4: A dictionary of the error type counts.
def compareEdits(hyp_edits, ref_edits):
    tp = 0  # True Positives
    fp = 0  # False Positives
    fn = 0  # False Negatives
    cat_dict = {}  # {cat: [tp, fp, fn], ...}

    for h_edit, h_cats in hyp_edits.items():
        # noop hyp edits cannot be TP or FP
        if h_cats[0] == "noop":
            continue
        # TRUE POSITIVES
        if h_edit in ref_edits.keys():
            # On occasion, multiple tokens at same span.
            for h_cat in ref_edits[h_edit]:  # Use ref dict for TP
                tp += 1
                # Each dict value [TP, FP, FN]
                if h_cat in cat_dict.keys():
                    cat_dict[h_cat][0] += 1
                else:
                    cat_dict[h_cat] = [1, 0, 0]
        # FALSE POSITIVES
        else:
            # On occasion, multiple tokens at same span.
            for h_cat in h_cats:
                fp += 1
                # Each dict value [TP, FP, FN]
                if h_cat in cat_dict.keys():
                    cat_dict[h_cat][1] += 1
                else:
                    cat_dict[h_cat] = [0, 1, 0]
    for r_edit, r_cats in ref_edits.items():
        # noop ref edits cannot be FN
        if r_cats[0] == "noop":
            continue
        # FALSE NEGATIVES
        if r_edit not in hyp_edits.keys():
            # On occasion, multiple tokens at same span.
            for r_cat in r_cats:
                fn += 1
                # Each dict value [TP, FP, FN]
                if r_cat in cat_dict.keys():
                    cat_dict[r_cat][2] += 1
                else:
                    cat_dict[r_cat] = [0, 0, 1]
    return tp, fp, fn, cat_dict


# Input 1-3: True positives, false positives, false negatives
# Input 4: Value of beta in F-score.
# Output 1-3: Precision, Recall and F-score rounded to 4dp.
def computeFScore(tp, fp, fn, beta):
    p = float(tp) / (tp + fp) if fp else 1.0
    r = float(tp) / (tp + fn) if fn else 1.0
    f = float((1 + (beta ** 2)) * p * r) / (((beta ** 2) * p) + r) if p + r else 0.0
    return round(p, 4), round(r, 4), round(f, 4)


# Input 1-2: Two error category dicts. Key is cat, value is list of TP, FP, FN.
# Output: The dictionaries combined with cumulative TP, FP, FN.
def merge_dict(dict1, dict2):
    for cat, stats in dict2.items():
        if cat in dict1.keys():
            dict1[cat] = [x + y for x, y in zip(dict1[cat], stats)]
        else:
            dict1[cat] = stats
    return dict1


# Input 1: A dict; key is error cat, value is counts for [tp, fp, fn]
# Input 2: Integer value denoting level of error category granularity.
# 1: Operation tier; e.g. M, R, U. 2: Main tier; e.g. NOUN, VERB 3: Everything.
# Output: A dictionary of category TP, FP and FN based on Input 2.
def processCategories(cat_dict, setting):
    # Otherwise, do some processing.
    proc_cat_dict = {}
    for cat, cnt in cat_dict.items():
        if cat == "UNK":
            proc_cat_dict[cat] = cnt
            continue
        # M, U, R or UNK combined only.
        if setting == 1:
            if cat[0] in proc_cat_dict.keys():
                proc_cat_dict[cat[0]] = [x + y for x, y in zip(proc_cat_dict[cat[0]], cnt)]
            else:
                proc_cat_dict[cat[0]] = cnt
        # Everything without M, U or R.
        elif setting == 2:
            if cat[2:] in proc_cat_dict.keys():
                proc_cat_dict[cat[2:]] = [x + y for x, y in zip(proc_cat_dict[cat[2:]], cnt)]
            else:
                proc_cat_dict[cat[2:]] = cnt
        # All error category combinations
        else:
            return cat_dict
    return proc_cat_dict


# Input 1: A dict of global best TP, FP and FNs
# Input 2: A dict of error types and counts for those TP, FP and FNs
# Input 3: Command line args
def print_results(best, best_cats, args):
    # Prepare output title.
    if args.dt: title = " Token-Based Detection "
    elif args.ds: title = " Span-Based Detection "
    elif args.cse: title = " Span-Based Correction + Classification "
    else: title = " Span-Based Correction "

    # Category Scores
    if args.cat:
        best_cats = processCategories(best_cats, args.cat)
        print("")
        print('{:=^66}'.format(title))
        print("Category".ljust(14), "TP".ljust(8), "FP".ljust(8), "FN".ljust(8),
              "P".ljust(8), "R".ljust(8), "F" + str(args.beta))
        for cat, cnts in sorted(best_cats.items()):
            cat_p, cat_r, cat_f = computeFScore(cnts[0], cnts[1], cnts[2], args.beta)
            print(cat.ljust(14), str(cnts[0]).ljust(8), str(cnts[1]).ljust(8),
                  str(cnts[2]).ljust(8), str(cat_p).ljust(8), str(cat_r).ljust(8), cat_f)

    # Print the overall results.
    print("")
    print('{:=^46}'.format(title))
    print("\t".join(["TP", "FP", "FN", "Prec", "Rec", "F" + str(args.beta)]))
    print("\t".join(map(str, [best["tp"], best["fp"],
                              best["fn"]] + list(computeFScore(best["tp"], best["fp"], best["fn"], args.beta)))))
    print('{:=^46}'.format(""))
    print("")


if __name__ == "__main__":
    # Run the program
    main()
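
For orientation, a minimal sketch of the F-beta computation used throughout this script, with toy counts (hypothetical standalone usage, not part of the commit; it assumes the script is importable from the current working directory):

# compare_m2_example.py -- hypothetical usage sketch, not part of the commit
from compare_m2_for_evaluation import computeFScore

# 8 true positives, 2 false positives, 4 false negatives, beta = 0.5
# precision = 8/10 = 0.8, recall = 8/12 ~ 0.6667; F0.5 weights precision more than recall
print(computeFScore(8, 2, 4, 0.5))  # (0.8, 0.6667, 0.7692)

# Typical command-line invocation (file names are placeholders; flags as defined in parse_args):
#   python compare_m2_for_evaluation.py -hyp hyp.m2 -ref ref.m2 -v -cat 3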
opencompass/datasets/lawbench/utils/function_utils.py (view file @ 4dd9a3fc)

from rouge_chinese import Rouge
import jieba
from nltk.translate.gleu_score import corpus_gleu


def compute_f1_two_sets(pred_set, gt_set):
    precision = len(pred_set.intersection(gt_set)) / len(pred_set) if len(pred_set) > 0 else 0
    recall = len(pred_set.intersection(gt_set)) / len(gt_set) if len(gt_set) > 0 else 0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0
    return f1


def multi_choice_judge(prediction, option_list, answer_token):
    # a dict, key: letters in the option list, value: count of the letter in the prediction
    count_dict, abstention, accuracy = {}, 0, 0
    for option in option_list:
        option_count = prediction.count(option)
        count_dict[option] = 1 if option_count > 0 else 0  # multiple occurrences of the same letter are counted as 1

    if sum(count_dict.values()) == 0:
        abstention = 1
    # if the answer token is the only predicted token, the prediction is correct
    elif count_dict[answer_token] == 1 and sum(count_dict.values()) == 1:
        accuracy = 1
    return {"score": accuracy, "abstention": abstention}


"""
compute the rouge score.
hyps and refs are lists of hypothesis and reference strings
empty predictions are replaced with 无内容
"""
def compute_rouge(hyps, refs):
    assert(len(hyps) == len(refs))
    hyps = [' '.join(jieba.cut(h)) for h in hyps]
    hyps = [h if h.strip() != "" else "无内容" for h in hyps]
    refs = [' '.join(jieba.cut(r)) for r in refs]
    return Rouge().get_scores(hyps, refs)


"""
compute the gleu score.
hyps and refs are lists of hypothesis and reference strings
empty predictions are replaced with 无内容
"""
def compute_gleu(hyps, refs):
    assert(len(hyps) == len(refs))
    hyps = [' '.join(jieba.cut(h)) for h in hyps]
    hyps = [h if h.strip() != "" else "无内容" for h in hyps]
    refs = [[' '.join(jieba.cut(r))] for r in refs]
    return corpus_gleu(refs, hyps)
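
A minimal usage sketch for the helpers above, with hypothetical toy inputs (not part of the commit; the import path assumes the interpreter is started from the utils directory, and compute_rouge/compute_gleu additionally require rouge_chinese, jieba and nltk to be installed):

# function_utils_example.py -- hypothetical usage sketch, not part of the commit
from function_utils import compute_f1_two_sets, multi_choice_judge

# Set-level F1 between a predicted label set and a gold label set
print(compute_f1_two_sets({"A", "B", "C"}, {"B", "C", "D"}))  # 0.666...

# Multiple-choice judging: exactly one option letter appears in the prediction and it is the answer
print(multi_choice_judge("答案是B", ["A", "B", "C", "D"], "B"))  # {'score': 1, 'abstention': 0}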
opencompass/datasets/lawbench/utils/modules/alignment.py
View file @
4dd9a3fc
import
numpy
as
np
import
numpy
as
np
from
typing
import
List
,
Tuple
,
Dict
from
typing
import
List
,
Tuple
,
Dict
from
modules.tokenizer
import
Tokenizer
from
modules.tokenizer
import
Tokenizer
import
os
import
os
from
string
import
punctuation
from
string
import
punctuation
REAL_PATH
=
os
.
path
.
split
(
os
.
path
.
realpath
(
__file__
))[
0
]
REAL_PATH
=
os
.
path
.
split
(
os
.
path
.
realpath
(
__file__
))[
0
]
chinese_punct
=
"!?。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘'‛“”„‟…‧﹏"
chinese_punct
=
"!?。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘'‛“”„‟…‧﹏"
english_punct
=
punctuation
english_punct
=
punctuation
punct
=
chinese_punct
+
english_punct
punct
=
chinese_punct
+
english_punct
def
check_all_chinese
(
word
):
def
check_all_chinese
(
word
):
"""
"""
判断一个单词是否全部由中文组成
判断一个单词是否全部由中文组成
:param word:
:param word:
:return:
:return:
"""
"""
return
all
([
'
\u4e00
'
<=
ch
<=
'
\u9fff
'
for
ch
in
word
])
return
all
([
'
\u4e00
'
<=
ch
<=
'
\u9fff
'
for
ch
in
word
])
def
read_cilin
():
def
read_cilin
():
"""
"""
Cilin 詞林 is a thesaurus with semantic information
Cilin 詞林 is a thesaurus with semantic information
"""
"""
# TODO -- fix this path
# TODO -- fix this path
project_dir
=
os
.
path
.
dirname
(
os
.
path
.
dirname
(
__file__
))
# ymliu@2023.5.30 fix the path
project_dir
=
os
.
path
.
dirname
(
os
.
path
.
dirname
(
__file__
))
# ymliu@2023.5.30 fix the path
lines
=
open
(
os
.
path
.
join
(
project_dir
,
"data"
,
"cilin.txt"
),
"r"
,
encoding
=
"gbk"
).
read
().
strip
().
split
(
"
\n
"
)
lines
=
open
(
os
.
path
.
join
(
project_dir
,
"data"
,
"cilin.txt"
),
"r"
,
encoding
=
"gbk"
).
read
().
strip
().
split
(
"
\n
"
)
semantic_dict
=
{}
semantic_dict
=
{}
semantic_classes
=
{}
semantic_classes
=
{}
for
line
in
lines
:
for
line
in
lines
:
code
,
*
words
=
line
.
split
(
" "
)
code
,
*
words
=
line
.
split
(
" "
)
for
word
in
words
:
for
word
in
words
:
semantic_dict
[
word
]
=
code
semantic_dict
[
word
]
=
code
# make reverse dict
# make reverse dict
if
code
in
semantic_classes
:
if
code
in
semantic_classes
:
semantic_classes
[
code
]
+=
words
semantic_classes
[
code
]
+=
words
else
:
else
:
semantic_classes
[
code
]
=
words
semantic_classes
[
code
]
=
words
return
semantic_dict
,
semantic_classes
return
semantic_dict
,
semantic_classes
def
read_confusion
():
def
read_confusion
():
confusion_dict
=
{}
confusion_dict
=
{}
project_dir
=
os
.
path
.
dirname
(
os
.
path
.
dirname
(
__file__
))
# ymliu@2023.5.30 fix the path
project_dir
=
os
.
path
.
dirname
(
os
.
path
.
dirname
(
__file__
))
# ymliu@2023.5.30 fix the path
with
open
(
os
.
path
.
join
(
project_dir
,
"data"
,
"confusion_dict.txt"
),
"r"
,
encoding
=
"utf-8"
)
as
f
:
with
open
(
os
.
path
.
join
(
project_dir
,
"data"
,
"confusion_dict.txt"
),
"r"
,
encoding
=
"utf-8"
)
as
f
:
for
line
in
f
:
for
line
in
f
:
li
=
line
.
rstrip
(
'
\n
'
).
split
(
" "
)
li
=
line
.
rstrip
(
'
\n
'
).
split
(
" "
)
confusion_dict
[
li
[
0
]]
=
li
[
1
:]
confusion_dict
[
li
[
0
]]
=
li
[
1
:]
return
confusion_dict
return
confusion_dict
class
Alignment
:
class
Alignment
:
"""
"""
对齐错误句子和正确句子,
对齐错误句子和正确句子,
使用编辑距离算法抽取编辑操作
使用编辑距离算法抽取编辑操作
"""
"""
def
__init__
(
def
__init__
(
self
,
self
,
semantic_dict
:
Dict
,
semantic_dict
:
Dict
,
confusion_dict
:
Dict
,
confusion_dict
:
Dict
,
granularity
:
str
=
"word"
,
granularity
:
str
=
"word"
,
)
->
None
:
)
->
None
:
"""
"""
构造函数
构造函数
:param semantic_dict: 语义词典(大词林)
:param semantic_dict: 语义词典(大词林)
:param confusion_dict: 字符混淆集
:param confusion_dict: 字符混淆集
"""
"""
self
.
insertion_cost
=
1
self
.
insertion_cost
=
1
self
.
deletion_cost
=
1
self
.
deletion_cost
=
1
self
.
semantic_dict
=
semantic_dict
self
.
semantic_dict
=
semantic_dict
self
.
confusion_dict
=
confusion_dict
self
.
confusion_dict
=
confusion_dict
# Because we use character level tokenization, this doesn't currently use POS
# Because we use character level tokenization, this doesn't currently use POS
self
.
_open_pos
=
{}
# 如果是词级别,还可以利用词性是否相同来计算cost
self
.
_open_pos
=
{}
# 如果是词级别,还可以利用词性是否相同来计算cost
self
.
granularity
=
granularity
# word-level or character-level
self
.
granularity
=
granularity
# word-level or character-level
self
.
align_seqs
=
[]
self
.
align_seqs
=
[]
def
__call__
(
self
,
def
__call__
(
self
,
src
:
List
[
Tuple
],
src
:
List
[
Tuple
],
tgt
:
List
[
Tuple
],
tgt
:
List
[
Tuple
],
verbose
:
bool
=
False
):
verbose
:
bool
=
False
):
cost_matrix
,
oper_matrix
=
self
.
align
(
src
,
tgt
)
cost_matrix
,
oper_matrix
=
self
.
align
(
src
,
tgt
)
align_seq
=
self
.
get_cheapest_align_seq
(
oper_matrix
)
align_seq
=
self
.
get_cheapest_align_seq
(
oper_matrix
)
if
verbose
:
if
verbose
:
print
(
"========== Seg. and POS: =========="
)
print
(
"========== Seg. and POS: =========="
)
print
(
src
)
print
(
src
)
print
(
tgt
)
print
(
tgt
)
print
(
"========== Cost Matrix =========="
)
print
(
"========== Cost Matrix =========="
)
print
(
cost_matrix
)
print
(
cost_matrix
)
print
(
"========== Oper Matrix =========="
)
print
(
"========== Oper Matrix =========="
)
print
(
oper_matrix
)
print
(
oper_matrix
)
print
(
"========== Alignment =========="
)
print
(
"========== Alignment =========="
)
print
(
align_seq
)
print
(
align_seq
)
print
(
"========== Results =========="
)
print
(
"========== Results =========="
)
for
a
in
align_seq
:
for
a
in
align_seq
:
print
(
a
[
0
],
src
[
a
[
1
]:
a
[
2
]],
tgt
[
a
[
3
]:
a
[
4
]])
print
(
a
[
0
],
src
[
a
[
1
]:
a
[
2
]],
tgt
[
a
[
3
]:
a
[
4
]])
return
align_seq
return
align_seq
def
_get_semantic_class
(
self
,
word
):
def
_get_semantic_class
(
self
,
word
):
"""
"""
NOTE: Based on the paper:
NOTE: Based on the paper:
Improved-Edit-Distance Kernel for Chinese Relation Extraction
Improved-Edit-Distance Kernel for Chinese Relation Extraction
获取每个词语的语义类别(基于大词林,有三个级别)
获取每个词语的语义类别(基于大词林,有三个级别)
"""
"""
if
word
in
self
.
semantic_dict
:
if
word
in
self
.
semantic_dict
:
code
=
self
.
semantic_dict
[
word
]
code
=
self
.
semantic_dict
[
word
]
high
,
mid
,
low
=
code
[
0
],
code
[
1
],
code
[
2
:
4
]
high
,
mid
,
low
=
code
[
0
],
code
[
1
],
code
[
2
:
4
]
return
high
,
mid
,
low
return
high
,
mid
,
low
else
:
# unknown
else
:
# unknown
return
None
return
None
@
staticmethod
@
staticmethod
def
_get_class_diff
(
a_class
,
b_class
):
def
_get_class_diff
(
a_class
,
b_class
):
"""
"""
d == 3 for equivalent semantics
d == 3 for equivalent semantics
d == 0 for completely different semantics
d == 0 for completely different semantics
根据大词林的信息,计算两个词的语义类别的差距
根据大词林的信息,计算两个词的语义类别的差距
"""
"""
d
=
sum
([
a
==
b
for
a
,
b
in
zip
(
a_class
,
b_class
)])
d
=
sum
([
a
==
b
for
a
,
b
in
zip
(
a_class
,
b_class
)])
return
d
return
d
def
_get_semantic_cost
(
self
,
a
,
b
):
def
_get_semantic_cost
(
self
,
a
,
b
):
"""
"""
计算基于语义信息的替换操作cost
计算基于语义信息的替换操作cost
:param a: 单词a的语义类别
:param a: 单词a的语义类别
:param b: 单词b的语义类别
:param b: 单词b的语义类别
:return: 替换编辑代价
:return: 替换编辑代价
"""
"""
a_class
=
self
.
_get_semantic_class
(
a
)
a_class
=
self
.
_get_semantic_class
(
a
)
b_class
=
self
.
_get_semantic_class
(
b
)
b_class
=
self
.
_get_semantic_class
(
b
)
# unknown class, default to 1
# unknown class, default to 1
if
a_class
is
None
or
b_class
is
None
:
if
a_class
is
None
or
b_class
is
None
:
return
4
return
4
elif
a_class
==
b_class
:
elif
a_class
==
b_class
:
return
0
return
0
else
:
else
:
return
2
*
(
3
-
self
.
_get_class_diff
(
a_class
,
b_class
))
return
2
*
(
3
-
self
.
_get_class_diff
(
a_class
,
b_class
))
def
_get_pos_cost
(
self
,
a_pos
,
b_pos
):
def
_get_pos_cost
(
self
,
a_pos
,
b_pos
):
"""
"""
计算基于词性信息的编辑距离cost
计算基于词性信息的编辑距离cost
:param a_pos: 单词a的词性
:param a_pos: 单词a的词性
:param b_pos: 单词b的词性
:param b_pos: 单词b的词性
:return: 替换编辑代价
:return: 替换编辑代价
"""
"""
if
a_pos
==
b_pos
:
if
a_pos
==
b_pos
:
return
0
return
0
elif
a_pos
in
self
.
_open_pos
and
b_pos
in
self
.
_open_pos
:
elif
a_pos
in
self
.
_open_pos
and
b_pos
in
self
.
_open_pos
:
return
0.25
return
0.25
else
:
else
:
return
0.499
return
0.499
def
_get_char_cost
(
self
,
a
,
b
,
pinyin_a
,
pinyin_b
):
def
_get_char_cost
(
self
,
a
,
b
,
pinyin_a
,
pinyin_b
):
"""
"""
NOTE: This is a replacement of ERRANTS lemma cost for Chinese
NOTE: This is a replacement of ERRANTS lemma cost for Chinese
计算基于字符相似度的编辑距离cost
计算基于字符相似度的编辑距离cost
"""
"""
if
not
(
check_all_chinese
(
a
)
and
check_all_chinese
(
b
)):
if
not
(
check_all_chinese
(
a
)
and
check_all_chinese
(
b
)):
return
0.5
return
0.5
if
len
(
a
)
>
len
(
b
):
if
len
(
a
)
>
len
(
b
):
a
,
b
=
b
,
a
a
,
b
=
b
,
a
pinyin_a
,
pinyin_b
=
pinyin_b
,
pinyin_a
pinyin_a
,
pinyin_b
=
pinyin_b
,
pinyin_a
if
a
==
b
:
if
a
==
b
:
return
0
return
0
else
:
else
:
return
self
.
_get_spell_cost
(
a
,
b
,
pinyin_a
,
pinyin_b
)
return
self
.
_get_spell_cost
(
a
,
b
,
pinyin_a
,
pinyin_b
)
def
_get_spell_cost
(
self
,
a
,
b
,
pinyin_a
,
pinyin_b
):
def
_get_spell_cost
(
self
,
a
,
b
,
pinyin_a
,
pinyin_b
):
"""
"""
计算两个单词拼写相似度,分别由字形相似度和字音相似度组成
计算两个单词拼写相似度,分别由字形相似度和字音相似度组成
:param a: 单词a
:param a: 单词a
:param b: 单词b,且单词a的长度小于等于b
:param b: 单词b,且单词a的长度小于等于b
:param pinyin_a: 单词a的拼音
:param pinyin_a: 单词a的拼音
:param pinyin_b: 单词b的拼音
:param pinyin_b: 单词b的拼音
:return: 替换操作cost
:return: 替换操作cost
"""
"""
count
=
0
count
=
0
for
i
in
range
(
len
(
a
)):
for
i
in
range
(
len
(
a
)):
for
j
in
range
(
len
(
b
)):
for
j
in
range
(
len
(
b
)):
if
a
[
i
]
==
b
[
j
]
or
(
set
(
pinyin_a
)
&
set
(
pinyin_b
))
or
(
b
[
j
]
in
self
.
confusion_dict
.
keys
()
and
a
[
i
]
in
self
.
confusion_dict
[
b
[
j
]])
or
(
a
[
i
]
in
self
.
confusion_dict
.
keys
()
and
b
[
j
]
in
self
.
confusion_dict
[
a
[
i
]]):
if
a
[
i
]
==
b
[
j
]
or
(
set
(
pinyin_a
)
&
set
(
pinyin_b
))
or
(
b
[
j
]
in
self
.
confusion_dict
.
keys
()
and
a
[
i
]
in
self
.
confusion_dict
[
b
[
j
]])
or
(
a
[
i
]
in
self
.
confusion_dict
.
keys
()
and
b
[
j
]
in
self
.
confusion_dict
[
a
[
i
]]):
count
+=
1
count
+=
1
break
break
return
(
len
(
a
)
-
count
)
/
(
len
(
a
)
*
2
)
return
(
len
(
a
)
-
count
)
/
(
len
(
a
)
*
2
)
def
get_sub_cost
(
self
,
a_seg
,
b_seg
):
def
get_sub_cost
(
self
,
a_seg
,
b_seg
):
"""
"""
Calculate the substitution cost between words a and b
Calculate the substitution cost between words a and b
计算两个单词替换操作的编辑cost,最大为2,等于一次删除和一次添加
计算两个单词替换操作的编辑cost,最大为2,等于一次删除和一次添加
"""
"""
if
a_seg
[
0
]
==
b_seg
[
0
]:
if
a_seg
[
0
]
==
b_seg
[
0
]:
return
0
return
0
if
self
.
granularity
==
"word"
:
# 词级别可以额外利用词性信息
if
self
.
granularity
==
"word"
:
# 词级别可以额外利用词性信息
semantic_cost
=
self
.
_get_semantic_cost
(
a_seg
[
0
],
b_seg
[
0
])
/
6.0
semantic_cost
=
self
.
_get_semantic_cost
(
a_seg
[
0
],
b_seg
[
0
])
/
6.0
pos_cost
=
self
.
_get_pos_cost
(
a_seg
[
1
],
b_seg
[
1
])
pos_cost
=
self
.
_get_pos_cost
(
a_seg
[
1
],
b_seg
[
1
])
char_cost
=
self
.
_get_char_cost
(
a_seg
[
0
],
b_seg
[
0
],
a_seg
[
2
],
b_seg
[
2
])
char_cost
=
self
.
_get_char_cost
(
a_seg
[
0
],
b_seg
[
0
],
a_seg
[
2
],
b_seg
[
2
])
return
semantic_cost
+
pos_cost
+
char_cost
return
semantic_cost
+
pos_cost
+
char_cost
else
:
# 字级别只能利用字义信息(从大词林中获取)和字面相似度信息
else
:
# 字级别只能利用字义信息(从大词林中获取)和字面相似度信息
semantic_cost
=
self
.
_get_semantic_cost
(
a_seg
[
0
],
b_seg
[
0
])
/
6.0
semantic_cost
=
self
.
_get_semantic_cost
(
a_seg
[
0
],
b_seg
[
0
])
/
6.0
if
a_seg
[
0
]
in
punct
and
b_seg
[
0
]
in
punct
:
if
a_seg
[
0
]
in
punct
and
b_seg
[
0
]
in
punct
:
pos_cost
=
0.0
pos_cost
=
0.0
elif
a_seg
[
0
]
not
in
punct
and
b_seg
[
0
]
not
in
punct
:
elif
a_seg
[
0
]
not
in
punct
and
b_seg
[
0
]
not
in
punct
:
pos_cost
=
0.25
pos_cost
=
0.25
else
:
else
:
pos_cost
=
0.499
pos_cost
=
0.499
# pos_cost = 0.0 if (a_seg[0] in punct and b_seg[0] in punct) or (a_seg[0] not in punct and b_seg[0] not in punct) else 0.5
# pos_cost = 0.0 if (a_seg[0] in punct and b_seg[0] in punct) or (a_seg[0] not in punct and b_seg[0] not in punct) else 0.5
char_cost
=
self
.
_get_char_cost
(
a_seg
[
0
],
b_seg
[
0
],
a_seg
[
2
],
b_seg
[
2
])
char_cost
=
self
.
_get_char_cost
(
a_seg
[
0
],
b_seg
[
0
],
a_seg
[
2
],
b_seg
[
2
])
return
semantic_cost
+
char_cost
+
pos_cost
return
semantic_cost
+
char_cost
+
pos_cost
def
align
(
self
,
def
align
(
self
,
src
:
List
[
Tuple
],
src
:
List
[
Tuple
],
tgt
:
List
[
Tuple
]):
tgt
:
List
[
Tuple
]):
"""
"""
Based on ERRANT's alignment
Based on ERRANT's alignment
基于改进的动态规划算法,为原句子的每个字打上编辑标签,以便使它能够成功转换为目标句子。
基于改进的动态规划算法,为原句子的每个字打上编辑标签,以便使它能够成功转换为目标句子。
编辑操作类别:
编辑操作类别:
1) M:Match,即KEEP,即当前字保持不变
1) M:Match,即KEEP,即当前字保持不变
2) D:Delete,删除,即当前字需要被删除
2) D:Delete,删除,即当前字需要被删除
3) I:Insert,插入,即当前字需要被插入
3) I:Insert,插入,即当前字需要被插入
4) T:Transposition,移位操作,即涉及到词序问题
4) T:Transposition,移位操作,即涉及到词序问题
"""
"""
cost_matrix
=
np
.
zeros
((
len
(
src
)
+
1
,
len
(
tgt
)
+
1
))
# 编辑cost矩阵
cost_matrix
=
np
.
zeros
((
len
(
src
)
+
1
,
len
(
tgt
)
+
1
))
# 编辑cost矩阵
oper_matrix
=
np
.
full
(
oper_matrix
=
np
.
full
(
(
len
(
src
)
+
1
,
len
(
tgt
)
+
1
),
"O"
,
dtype
=
object
(
len
(
src
)
+
1
,
len
(
tgt
)
+
1
),
"O"
,
dtype
=
object
)
# 操作矩阵
)
# 操作矩阵
# Fill in the edges
# Fill in the edges
for
i
in
range
(
1
,
len
(
src
)
+
1
):
for
i
in
range
(
1
,
len
(
src
)
+
1
):
cost_matrix
[
i
][
0
]
=
cost_matrix
[
i
-
1
][
0
]
+
1
cost_matrix
[
i
][
0
]
=
cost_matrix
[
i
-
1
][
0
]
+
1
oper_matrix
[
i
][
0
]
=
[
"D"
]
oper_matrix
[
i
][
0
]
=
[
"D"
]
for
j
in
range
(
1
,
len
(
tgt
)
+
1
):
for
j
in
range
(
1
,
len
(
tgt
)
+
1
):
cost_matrix
[
0
][
j
]
=
cost_matrix
[
0
][
j
-
1
]
+
1
cost_matrix
[
0
][
j
]
=
cost_matrix
[
0
][
j
-
1
]
+
1
oper_matrix
[
0
][
j
]
=
[
"I"
]
oper_matrix
[
0
][
j
]
=
[
"I"
]
# Loop through the cost matrix
# Loop through the cost matrix
for
i
in
range
(
len
(
src
)):
for
i
in
range
(
len
(
src
)):
for
j
in
range
(
len
(
tgt
)):
for
j
in
range
(
len
(
tgt
)):
# Matches
# Matches
if
src
[
i
][
0
]
==
tgt
[
j
][
0
]:
# 如果两个字相等,则匹配成功(Match),编辑距离为0
if
src
[
i
][
0
]
==
tgt
[
j
][
0
]:
# 如果两个字相等,则匹配成功(Match),编辑距离为0
cost_matrix
[
i
+
1
][
j
+
1
]
=
cost_matrix
[
i
][
j
]
cost_matrix
[
i
+
1
][
j
+
1
]
=
cost_matrix
[
i
][
j
]
oper_matrix
[
i
+
1
][
j
+
1
]
=
[
"M"
]
oper_matrix
[
i
+
1
][
j
+
1
]
=
[
"M"
]
# Non-matches
# Non-matches
else
:
else
:
del_cost
=
cost_matrix
[
i
][
j
+
1
]
+
self
.
deletion_cost
# 由删除动作得到的总cost
del_cost
=
cost_matrix
[
i
][
j
+
1
]
+
self
.
deletion_cost
# 由删除动作得到的总cost
ins_cost
=
cost_matrix
[
i
+
1
][
j
]
+
self
.
insertion_cost
# 由插入动作得到的总cost
ins_cost
=
cost_matrix
[
i
+
1
][
j
]
+
self
.
insertion_cost
# 由插入动作得到的总cost
sub_cost
=
cost_matrix
[
i
][
j
]
+
self
.
get_sub_cost
(
sub_cost
=
cost_matrix
[
i
][
j
]
+
self
.
get_sub_cost
(
src
[
i
],
tgt
[
j
]
src
[
i
],
tgt
[
j
]
)
# 由替换动作得到的总cost
)
# 由替换动作得到的总cost
# Calculate transposition cost
# Calculate transposition cost
# 计算移位操作的总cost
# 计算移位操作的总cost
trans_cost
=
float
(
"inf"
)
trans_cost
=
float
(
"inf"
)
k
=
1
k
=
1
while
(
while
(
i
-
k
>=
0
i
-
k
>=
0
and
j
-
k
>=
0
and
j
-
k
>=
0
and
cost_matrix
[
i
-
k
+
1
][
j
-
k
+
1
]
and
cost_matrix
[
i
-
k
+
1
][
j
-
k
+
1
]
!=
                    != cost_matrix[i - k][j - k]
                ):
                    p1 = sorted([a[0] for a in src][i - k: i + 1])
                    p2 = sorted([b[0] for b in tgt][j - k: j + 1])
                    if p1 == p2:
                        trans_cost = cost_matrix[i - k][j - k] + k
                        break
                    k += 1
                costs = [trans_cost, sub_cost, ins_cost, del_cost]
                ind = costs.index(min(costs))
                cost_matrix[i + 1][j + 1] = costs[ind]
                # ind = costs.index(costs[ind], ind+1)
                for idx, cost in enumerate(costs):
                    if cost == costs[ind]:
                        if idx == 0:
                            if oper_matrix[i + 1][j + 1] == "O":
                                oper_matrix[i + 1][j + 1] = ["T" + str(k + 1)]
                            else:
                                oper_matrix[i + 1][j + 1].append("T" + str(k + 1))
                        elif idx == 1:
                            if oper_matrix[i + 1][j + 1] == "O":
                                oper_matrix[i + 1][j + 1] = ["S"]
                            else:
                                oper_matrix[i + 1][j + 1].append("S")
                        elif idx == 2:
                            if oper_matrix[i + 1][j + 1] == "O":
                                oper_matrix[i + 1][j + 1] = ["I"]
                            else:
                                oper_matrix[i + 1][j + 1].append("I")
                        else:
                            if oper_matrix[i + 1][j + 1] == "O":
                                oper_matrix[i + 1][j + 1] = ["D"]
                            else:
                                oper_matrix[i + 1][j + 1].append("D")
        return cost_matrix, oper_matrix

    def _dfs(self, i, j, align_seq_now, oper_matrix, strategy="all"):
        """
        Depth-first traversal that collects every alignment sequence sharing the minimum edit distance.
        """
        if i + j == 0:
            self.align_seqs.append(align_seq_now)
        else:
            ops = oper_matrix[i][j]  # analogous to enumerating all root-to-leaf paths of a tree
            if strategy != "all": ops = ops[:1]
            for op in ops:
                if op in {"M", "S"}:
                    self._dfs(i - 1, j - 1, align_seq_now + [(op, i - 1, i, j - 1, j)], oper_matrix, strategy)
                elif op == "D":
                    self._dfs(i - 1, j, align_seq_now + [(op, i - 1, i, j, j)], oper_matrix, strategy)
                elif op == "I":
                    self._dfs(i, j - 1, align_seq_now + [(op, i, i, j - 1, j)], oper_matrix, strategy)
                else:
                    k = int(op[1:])
                    self._dfs(i - k, j - k, align_seq_now + [(op, i - k, i, j - k, j)], oper_matrix, strategy)

    def get_cheapest_align_seq(self, oper_matrix):
        """
        Backtrack to recover the edit sequences with the minimum edit distance.
        """
        self.align_seqs = []
        i = oper_matrix.shape[0] - 1
        j = oper_matrix.shape[1] - 1
        if abs(i - j) > 10:
            self._dfs(i, j, [], oper_matrix, "first")
        else:
            self._dfs(i, j, [], oper_matrix, "all")
        final_align_seqs = [seq[::-1] for seq in self.align_seqs]
        return final_align_seqs


if __name__ == "__main__":
    tokenizer = Tokenizer("word")
    semantic_dict, semantic_class = read_cilin()
    confusion_dict = read_confusion()
    alignment = Alignment(semantic_dict, confusion_dict)
    sents = ["首先 , 我们 得 准备 : 大 虾六 到 九 只 、 盐 一 茶匙 、 已 搾 好 的 柠檬汁 三 汤匙 、 泰国 柠檬 叶三叶 、 柠檬 香草 一 根 、 鱼酱 两 汤匙 、 辣椒 6 粒 , 纯净 水 4量杯 、 香菜 半量杯 和 草菇 10 个 。".replace(" ", ""),
             "首先 , 我们 得 准备 : 大 虾六 到 九 只 、 盐 一 茶匙 、 已 榨 好 的 柠檬汁 三 汤匙 、 泰国 柠檬 叶三叶 、 柠檬 香草 一 根 、 鱼酱 两 汤匙 、 辣椒 六 粒 , 纯净 水 四 量杯 、 香菜 半量杯 和 草菇 十 个 。".replace(" ", "")]
    src, tgt = tokenizer(sents)
    alignment(src, tgt, verbose=True)
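The _dfs and get_cheapest_align_seq methods above backtrack through oper_matrix from the bottom-right cell and enumerate every path whose edits reach the minimum cost. The following standalone sketch illustrates that backtracking idea on a hand-built toy operation matrix; it is not part of the repository, and the helper name enumerate_alignments and the toy matrix are illustrative assumptions (the matrix corresponds to aligning a two-token source "AB" with a two-token target "AC").

# Toy operation matrix: cell (i, j) lists the operations that reach it at
# minimum cost, using the same codes as alignment.py ("M" match, "S" substitute,
# "D" delete, "I" insert). Cell (0, 0) is the start and is never read.
def enumerate_alignments(oper, i, j, path=()):
    if i == 0 and j == 0:
        yield path[::-1]
        return
    for op in oper[i][j]:
        if op in {"M", "S"}:
            yield from enumerate_alignments(oper, i - 1, j - 1, path + ((op, i - 1, j - 1),))
        elif op == "D":
            yield from enumerate_alignments(oper, i - 1, j, path + ((op, i - 1, j),))
        elif op == "I":
            yield from enumerate_alignments(oper, i, j - 1, path + ((op, i, j - 1),))

toy = [
    [["O"], ["I"], ["I"]],
    [["D"], ["M"], ["I"]],
    [["D"], ["D"], ["S"]],
]
print(list(enumerate_alignments(toy, 2, 2)))
# [(('M', 0, 0), ('S', 1, 1))]: keep the first token, substitute the second.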
opencompass/datasets/lawbench/utils/modules/annotator.py View file @ 4dd9a3fc

from typing import List, Tuple
from modules.alignment import read_cilin, read_confusion, Alignment
from modules.merger import Merger
from modules.classifier import Classifier


class Annotator:
    def __init__(self,
                 align: Alignment,
                 merger: Merger,
                 classifier: Classifier,
                 granularity: str = "word",
                 strategy: str = "first"):
        self.align = align
        self.merger = merger
        self.classifier = classifier
        self.granularity = granularity
        self.strategy = strategy

    @classmethod
    def create_default(cls, granularity: str = "word", strategy: str = "first"):
        """
        Default parameters used in the paper
        """
        semantic_dict, semantic_class = read_cilin()
        confusion_dict = read_confusion()
        align = Alignment(semantic_dict, confusion_dict, granularity)
        merger = Merger(granularity)
        classifier = Classifier(granularity)
        return cls(align, merger, classifier, granularity, strategy)

    def __call__(self,
                 src: List[Tuple],
                 tgt: List[Tuple],
                 annotator_id: int = 0,
                 verbose: bool = False):
        """
        Align sentences and annotate them with error type information
        """
        src_tokens = [x[0] for x in src]
        tgt_tokens = [x[0] for x in tgt]
        src_str = "".join(src_tokens)
        tgt_str = "".join(tgt_tokens)
        # convert to text form
        annotations_out = ["S " + " ".join(src_tokens) + "\n"]
        if tgt_str == "没有错误" or src_str == tgt_str:  # Error Free Case
            annotations_out.append(f"T{annotator_id} 没有错误\n")
            cors = [tgt_str]
            op, toks, inds = "noop", "-NONE-", (-1, -1)
            a_str = f"A {inds[0]} {inds[1]}|||{op}|||{toks}|||REQUIRED|||-NONE-|||{annotator_id}\n"
            annotations_out.append(a_str)
        elif tgt_str == "无法标注":  # Not Annotatable Case
            annotations_out.append(f"T{annotator_id} 无法标注\n")
            cors = [tgt_str]
            op, toks, inds = "NA", "-NONE-", (-1, -1)
            a_str = f"A {inds[0]} {inds[1]}|||{op}|||{toks}|||REQUIRED|||-NONE-|||{annotator_id}\n"
            annotations_out.append(a_str)
        else:  # Other
            align_objs = self.align(src, tgt)
            edit_objs = []
            align_idx = 0
            if self.strategy == "first":
                align_objs = align_objs[:1]
            for align_obj in align_objs:
                edits = self.merger(align_obj, src, tgt, verbose)
                if edits not in edit_objs:
                    edit_objs.append(edits)
                    annotations_out.append(f"T{annotator_id}-A{align_idx} " + " ".join(tgt_tokens) + "\n")
                    align_idx += 1
                    cors = self.classifier(src, tgt, edits, verbose)
                    # annotations_out = []
                    for cor in cors:
                        op, toks, inds = cor.op, cor.toks, cor.inds
                        a_str = f"A {inds[0]} {inds[1]}|||{op}|||{toks}|||REQUIRED|||-NONE-|||{annotator_id}\n"
                        annotations_out.append(a_str)
        annotations_out.append("\n")
        return annotations_out, cors
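A minimal usage sketch of this annotator follows. It assumes the ChERRANT resources the modules load (the cilin and confusion dictionaries, data/char_meta.txt, and an LTP model when word granularity is used) are available locally; the sample sentence pair and variable names are made up for illustration.

from modules.tokenizer import Tokenizer
from modules.annotator import Annotator

tokenizer = Tokenizer("char")
annotator = Annotator.create_default(granularity="char")

# Hypothetical source/target pair differing by a single character substitution.
src_sent = "他昨天去学校上果了"
tgt_sent = "他昨天去学校上课了"
src, tgt = tokenizer([src_sent, tgt_sent])
m2_lines, corrections = annotator(src, tgt, annotator_id=0)
print("".join(m2_lines))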
opencompass/datasets/lawbench/utils/modules/classifier.py View file @ 4dd9a3fc

from char_smi import CharFuncs
from collections import namedtuple
from pypinyin import pinyin, Style
import os

Correction = namedtuple(
    "Correction",
    [
        "op",
        "toks",
        "inds",
    ],
)
file_path = os.path.dirname(os.path.abspath(__file__))
char_smi = CharFuncs(os.path.join(file_path.replace("modules", ""), 'data/char_meta.txt'))


def check_spell_error(src_span: str,
                      tgt_span: str,
                      threshold: float = 0.8) -> bool:
    if len(src_span) != len(tgt_span):
        return False
    src_chars = [ch for ch in src_span]
    tgt_chars = [ch for ch in tgt_span]
    if sorted(src_chars) == sorted(tgt_chars):  # characters transposed within the word
        return True
    for src_char, tgt_char in zip(src_chars, tgt_chars):
        if src_char != tgt_char:
            if src_char not in char_smi.data or tgt_char not in char_smi.data:
                return False
            v_sim = char_smi.shape_similarity(src_char, tgt_char)
            p_sim = char_smi.pronunciation_similarity(src_char, tgt_char)
            if v_sim + p_sim < threshold and not (
                    set(pinyin(src_char, style=Style.NORMAL, heteronym=True)[0]) & set(pinyin(tgt_char, style=Style.NORMAL, heteronym=True)[0])):
                return False
    return True
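Two of the branches above can be exercised without touching the similarity tables, since they depend only on span length and character multisets. A small illustration (the word pairs are made up; importing the module still requires data/char_meta.txt to be present):

from modules.classifier import check_spell_error

# Transposed characters inside a word count as a spelling error.
print(check_spell_error("记登", "登记"))    # True
# Spans of different lengths are never treated as spelling errors.
print(check_spell_error("登记", "登记了"))  # False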
class Classifier:
    """
    Error type classifier.
    """
    def __init__(self,
                 granularity: str = "word"):
        self.granularity = granularity

    @staticmethod
    def get_pos_type(pos):
        if pos in {"n", "nd"}:
            return "NOUN"
        if pos in {"nh", "ni", "nl", "ns", "nt", "nz"}:
            return "NOUN-NE"
        if pos in {"v"}:
            return "VERB"
        if pos in {"a", "b"}:
            return "ADJ"
        if pos in {"c"}:
            return "CONJ"
        if pos in {"r"}:
            return "PRON"
        if pos in {"d"}:
            return "ADV"
        if pos in {"u"}:
            return "AUX"
        # if pos in {"k"}:  # TODO: suffix words are too rare for now, grouped into OTHER
        #     return "SUFFIX"
        if pos in {"m"}:
            return "NUM"
        if pos in {"p"}:
            return "PREP"
        if pos in {"q"}:
            return "QUAN"
        if pos in {"wp"}:
            return "PUNCT"
        return "OTHER"

    def __call__(self,
                 src,
                 tgt,
                 edits,
                 verbose: bool = False):
        """
        Assign an error type to each edit operation.
        :param src: information of the erroneous (source) sentence
        :param tgt: information of the corrected (target) sentence
        :param edits: edit operations
        :param verbose: whether to print details
        :return: edit operations with error types assigned
        """
        results = []
        src_tokens = [x[0] for x in src]
        tgt_tokens = [x[0] for x in tgt]
        for edit in edits:
            error_type = edit[0]
            src_span = " ".join(src_tokens[edit[1]: edit[2]])
            tgt_span = " ".join(tgt_tokens[edit[3]: edit[4]])
            # print(tgt_span)
            cor = None
            if error_type[0] == "T":
                cor = Correction("W", tgt_span, (edit[1], edit[2]))
            elif error_type[0] == "D":
                if self.granularity == "word":  # word level supports fine-grained error types
                    if edit[2] - edit[1] > 1:  # redundant multi-word spans are classified as OTHER for now
                        cor = Correction("R:OTHER", "-NONE-", (edit[1], edit[2]))
                    else:
                        pos = self.get_pos_type(src[edit[1]][1])
                        pos = "NOUN" if pos == "NOUN-NE" else pos
                        pos = "MC" if tgt_span == "[缺失成分]" else pos
                        cor = Correction("R:{:s}".format(pos), "-NONE-", (edit[1], edit[2]))
                else:  # at char level the type is decided by the operation alone
                    cor = Correction("R", "-NONE-", (edit[1], edit[2]))
            elif error_type[0] == "I":
                if self.granularity == "word":  # word level supports fine-grained error types
                    if edit[4] - edit[3] > 1:  # missing multi-word spans are classified as OTHER for now
                        cor = Correction("M:OTHER", tgt_span, (edit[1], edit[2]))
                    else:
                        pos = self.get_pos_type(tgt[edit[3]][1])
                        pos = "NOUN" if pos == "NOUN-NE" else pos
                        pos = "MC" if tgt_span == "[缺失成分]" else pos
                        cor = Correction("M:{:s}".format(pos), tgt_span, (edit[1], edit[2]))
                else:  # at char level the type is decided by the operation alone
                    cor = Correction("M", tgt_span, (edit[1], edit[2]))
            elif error_type[0] == "S":
                if self.granularity == "word":  # word level supports fine-grained error types
                    if check_spell_error(src_span.replace(" ", ""), tgt_span.replace(" ", "")):
                        cor = Correction("S:SPELL", tgt_span, (edit[1], edit[2]))
                        # Todo: named-entity spelling errors are not distinguished separately for now
                        # if edit[4] - edit[3] > 1:
                        #     cor = Correction("S:SPELL:COMMON", tgt_span, (edit[1], edit[2]))
                        # else:
                        #     pos = self.get_pos_type(tgt[edit[3]][1])
                        #     if pos == "NOUN-NE":  # a named entity is misspelled
                        #         cor = Correction("S:SPELL:NE", tgt_span, (edit[1], edit[2]))
                        #     else:  # an ordinary word is misspelled
                        #         cor = Correction("S:SPELL:COMMON", tgt_span, (edit[1], edit[2]))
                    else:
                        if edit[4] - edit[3] > 1:  # replaced multi-word spans are classified as OTHER for now
                            cor = Correction("S:OTHER", tgt_span, (edit[1], edit[2]))
                        else:
                            pos = self.get_pos_type(tgt[edit[3]][1])
                            pos = "NOUN" if pos == "NOUN-NE" else pos
                            pos = "MC" if tgt_span == "[缺失成分]" else pos
                            cor = Correction("S:{:s}".format(pos), tgt_span, (edit[1], edit[2]))
                else:  # at char level the type is decided by the operation alone
                    cor = Correction("S", tgt_span, (edit[1], edit[2]))
            results.append(cor)
        if verbose:
            print("========== Corrections ==========")
            for cor in results:
                print("Type: {:s}, Position: {:d} -> {:d}, Target: {:s}".format(cor.op, cor.inds[0], cor.inds[1], cor.toks))
        return results


# print(pinyin("朝", style=Style.NORMAL))
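The POS tags fed to get_pos_type are the LTP-style tags carried in the second field of each token tuple; the mapping collapses them into the coarse categories that appear in the final error labels. A brief illustration with made-up values:

from modules.classifier import Classifier, Correction

clf = Classifier(granularity="word")
print(clf.get_pos_type("ns"))   # "NOUN-NE", later folded into "NOUN" in __call__
print(clf.get_pos_type("wp"))   # "PUNCT"
print(clf.get_pos_type("xyz"))  # "OTHER" for anything unrecognised

# The records returned by __call__ are plain namedtuples.
cor = Correction(op="M:NOUN", toks="学校", inds=(3, 3))
print(cor.op, cor.toks, cor.inds)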
opencompass/datasets/lawbench/utils/modules/merger.py View file @ 4dd9a3fc

from itertools import groupby
from string import punctuation
from typing import List
from modules.tokenizer import Tokenizer
from modules.alignment import Alignment, read_cilin, read_confusion
import Levenshtein


class Merger:
    """
    Merge edit operations, converting token-level edits into span-level edits.
    """

    def __init__(self,
                 granularity: str = "word",
                 merge: bool = False):
        chinese_punct = "!?。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟–—‘'‛“”„‟…‧."
        self.punctuation = punctuation + chinese_punct
        self.not_merge_token = [punct for punct in self.punctuation]
        self.granularity = granularity
        self.merge = merge

    @staticmethod
    def _merge_edits(seq, tag="X"):
        if seq:
            return [(tag, seq[0][1], seq[-1][2], seq[0][3], seq[-1][4])]
        else:
            return seq

    @staticmethod
    def _check_revolve(span_a, span_b):
        span_a = span_a + span_a
        return span_b in span_a

    def _process_seq(self, seq, src_tokens, tgt_tokens):
        if len(seq) <= 1:
            return seq

        ops = [op[0] for op in seq]
        if set(ops) == {"D"} or set(ops) == {"I"}:
            return self._merge_edits(seq, set(ops).pop())

        if set(ops) == {"D", "I"} or set(ops) == {"I", "D"}:
            # do not merge this pattern_from_qua.txt
            return seq

        if set(ops) == {"S"}:
            if self.granularity == "word":
                return seq
            else:
                return self._merge_edits(seq, "S")

        if set(ops) == {"M"}:
            return self._merge_edits(seq, "M")

        return self._merge_edits(seq, "S")
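_process_seq collapses a run of same-type token edits into one span edit via _merge_edits, while mixed delete/insert runs are left untouched. A quick illustration with invented edit tuples of the form (op, src_start, src_end, tgt_start, tgt_end), assuming the module's dependencies (LTP, torch, Levenshtein) are installed so the import succeeds:

from modules.merger import Merger

m = Merger(granularity="word")
run_of_deletes = [("D", 2, 3, 4, 4), ("D", 3, 4, 4, 4), ("D", 4, 5, 4, 4)]
print(m._process_seq(run_of_deletes, [], []))   # [("D", 2, 5, 4, 4)]

mixed = [("D", 2, 3, 4, 4), ("I", 3, 3, 4, 5)]
print(m._process_seq(mixed, [], []))            # returned unchanged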
    def __call__(self,
                 align_obj,
                 src: List,
                 tgt: List,
                 verbose: bool = False):
        """
        Based on ERRANT's merge, adapted for Chinese
        """
        src_tokens = [x[0] for x in src]
        tgt_tokens = [x[0] for x in tgt]
        edits = []
        # Split alignment into groups of M, T and rest. (T has a number after it)
        # Todo: once an inserted/deleted/substituted span contains punctuation, do not merge it with other edits
        # Todo: the missing-component tag is not merged with other edits either
        for op, group in groupby(
            align_obj,
            lambda x: x[0][0] if x[0][0] in {"M", "T"} else False,
        ):
            group = list(group)
            # T is always split TODO: Evaluate this
            if op == "T":
                for seq in group:
                    edits.append(seq)
            # Process D, I and S subsequence
            else:
                # Turn the processed sequence into edits
                processed = self._process_seq(group, src_tokens, tgt_tokens)
                for seq in processed:
                    edits.append(seq)

        filtered_edits = []
        i = 0
        while i < len(edits):
            e1 = edits[i][0][0]
            if i < len(edits) - 2:
                e2 = edits[i + 1][0][0]
                e3 = edits[i + 2][0][0]
                # Find "S M S" patterns
                # Ex:
                # S M S
                # 冬阴功 对 外国人
                # 外国人 对 冬阴功
                if e1 == "S" and e2 == "M" and e3 == "S":
                    w1 = "".join(src_tokens[edits[i][1]: edits[i][2]])
                    w2 = "".join(tgt_tokens[edits[i][3]: edits[i][4]])
                    w3 = "".join(src_tokens[edits[i + 2][1]: edits[i + 2][2]])
                    w4 = "".join(tgt_tokens[edits[i + 2][3]: edits[i + 2][4]])
                    if min([len(w1), len(w2), len(w3), len(w4)]) == 1:
                        if w1 == w4 and w2 == w3:
                            group = [edits[i], edits[i + 1], edits[i + 2]]
                            processed = self._merge_edits(group, "T" + str(edits[i + 2][2] - edits[i][1]))
                            for seq in processed:
                                filtered_edits.append(seq)
                            i += 3
                        else:
                            filtered_edits.append(edits[i])
                            i += 1
                    else:
                        if Levenshtein.distance(w1, w4) <= 1 and Levenshtein.distance(w2, w3) <= 1:
                            group = [edits[i], edits[i + 1], edits[i + 2]]
                            processed = self._merge_edits(group, "T" + str(edits[i + 2][2] - edits[i][1]))
                            for seq in processed:
                                filtered_edits.append(seq)
                            i += 3
                        else:
                            filtered_edits.append(edits[i])
                            i += 1
                # Find "D M I" or "I M D" patterns
                # Ex:
                # D M I
                # 旅游 去 陌生 的 地方
                # 去 陌生 的 地方 旅游
                elif (e1 == "D" and (e2 == "M" or e2.startswith("T")) and e3 == "I") or (e1 == "I" and (e2 == "M" or e2.startswith("T")) and e3 == "D"):
                    if e1 == "D":
                        delete_token = src_tokens[edits[i][1]: edits[i][2]]
                        insert_token = tgt_tokens[edits[i + 2][3]: edits[i + 2][4]]
                    else:
                        delete_token = src_tokens[edits[i + 2][1]: edits[i + 2][2]]
                        insert_token = tgt_tokens[edits[i][3]: edits[i][4]]
                    a, b = "".join(delete_token), "".join(insert_token)
                    if len(a) < len(b):
                        a, b = b, a
                    if a not in self.punctuation and b not in self.punctuation and len(a) - len(b) <= 1:
                        if len(b) == 1:
                            if a == b:
                                group = [edits[i], edits[i + 1], edits[i + 2]]
                                processed = self._merge_edits(group, "T" + str(edits[i + 2][2] - edits[i][1]))
                                for seq in processed:
                                    filtered_edits.append(seq)
                                i += 3
                            else:
                                filtered_edits.append(edits[i])
                                i += 1
                        else:
                            if Levenshtein.distance(a, b) <= 1 or (len(a) == len(b) and self._check_revolve(a, b)):
                                group = [edits[i], edits[i + 1], edits[i + 2]]
                                processed = self._merge_edits(group, "T" + str(edits[i + 2][2] - edits[i][1]))
                                for seq in processed:
                                    filtered_edits.append(seq)
                                i += 3
                            else:
                                filtered_edits.append(edits[i])
                                i += 1
                    else:
                        filtered_edits.append(edits[i])
                        i += 1
                else:
                    if e1 != "M":
                        filtered_edits.append(edits[i])
                    i += 1
            else:
                if e1 != "M":
                    filtered_edits.append(edits[i])
                i += 1
        # In rare cases with word-level tokenization, the following error can occur:
        # M D S M
        # 有 時 住 上層
        # 有 時住 上層
        # Which results in S: 時住 --> 時住
        # We need to filter this case out
        second_filter = []
        for edit in filtered_edits:  # avoid mismatches caused by word-segmentation errors
            span1 = "".join(src_tokens[edit[1]: edit[2]])
            span2 = "".join(tgt_tokens[edit[3]: edit[4]])
            if span1 != span2:
                if edit[0] == "S":
                    b = True
                    # In rare cases with word-level tokenization, the following error can occur:
                    # S I I M
                    # 负责任 老师
                    # 负 责任 的 老师
                    # Which results in S: 负责任 --> 负 责任 的
                    # We need to convert this edit to I: --> 的
                    # overlap at the start of the span
                    common_str = ""
                    tmp_new_start_1 = edit[1]
                    for i in range(edit[1], edit[2]):
                        if not span2.startswith(common_str + src_tokens[i]):
                            break
                        common_str += src_tokens[i]
                        tmp_new_start_1 = i + 1
                    new_start_1, new_start_2 = edit[1], edit[3]
                    if common_str:
                        tmp_str = ""
                        for i in range(edit[3], edit[4]):
                            tmp_str += tgt_tokens[i]
                            if tmp_str == common_str:
                                new_start_1, new_start_2 = tmp_new_start_1, i + 1
                                # second_filter.append(("S", new_start_1, edit[2], i + 1, edit[4]))
                                b = False
                                break
                            elif len(tmp_str) > len(common_str):
                                break
                    # overlap at the end of the span
                    common_str = ""
                    new_end_1, new_end_2 = edit[2], edit[4]
                    tmp_new_end_1 = edit[2]
                    for i in reversed(range(new_start_1, edit[2])):
                        if not span2.endswith(src_tokens[i] + common_str):
                            break
                        common_str = src_tokens[i] + common_str
                        tmp_new_end_1 = i
                    if common_str:
                        tmp_str = ""
                        for i in reversed(range(new_start_2, edit[4])):
                            tmp_str = tgt_tokens[i] + tmp_str
                            if tmp_str == common_str:
                                new_end_1, new_end_2 = tmp_new_end_1, i
                                b = False
                                break
                            elif len(tmp_str) > len(common_str):
                                break
                    if b:
                        second_filter.append(edit)
                    else:
                        if new_start_1 == new_end_1:
                            new_edit = ("I", new_start_1, new_end_1, new_start_2, new_end_2)
                        elif new_start_2 == new_end_2:
                            new_edit = ("D", new_start_1, new_end_1, new_start_2, new_end_2)
                        else:
                            new_edit = ("S", new_start_1, new_end_1, new_start_2, new_end_2)
                        second_filter.append(new_edit)
                else:
                    second_filter.append(edit)
        if verbose:
            print("========== Parallels ==========")
            print("".join(src_tokens))
            print("".join(tgt_tokens))
            print("========== Results ==========")
            for edit in second_filter:
                op = edit[0]
                s = " ".join(src_tokens[edit[1]: edit[2]])
                t = " ".join(tgt_tokens[edit[3]: edit[4]])
                print(f"{op}:\t{s}\t-->\t{t}")
            print("========== Infos ==========")
            print(str(src))
            print(str(tgt))
        return second_filter


if __name__ == "__main__":
    tokenizer = Tokenizer("char")
    semantic_dict, semantic_class = read_cilin()
    confusion_dict = read_confusion()
    alignment = Alignment(semantic_dict, confusion_dict)
    sents = [
        "所 以 印 度 对 全 世 界 人 没 有 说 服 不 要 吃 牛 肉 。".replace(" ", ""),
        "所 以 印 度 没 有 说 服 全 世 界 人 不 要 吃 牛 肉 。".replace(" ", "")]
    src, tgt = tokenizer(sents)
    align_obj = alignment(src, tgt)
    m = Merger()
    m(align_obj, src, tgt, verbose=True)
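The rotation test _check_revolve used in the "D M I" branch above decides whether the deleted and the inserted spans are the same characters in a different order, which marks the pair as a word-order (transposition) error rather than an unrelated delete plus insert. A short illustration, reusing the word from the comment example above:

# "功冬阴" is a rotation of "冬阴功" because it occurs inside "冬阴功" + "冬阴功".
from modules.merger import Merger
print(Merger._check_revolve("冬阴功", "功冬阴"))  # True
print(Merger._check_revolve("冬阴功", "功阴冬"))  # False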
opencompass/datasets/lawbench/utils/modules/tokenizer.py View file @ 4dd9a3fc

from ltp import LTP
from typing import List
from pypinyin import pinyin, Style, lazy_pinyin
import torch
import os
import functools


class Tokenizer:
    """
    Tokenizer.
    """

    def __init__(self,
                 granularity: str = "word",
                 device: str = "cpu",
                 segmented: bool = False,
                 bpe: bool = False,
                 ) -> None:
        """
        Constructor.
        :param mode: tokenization mode; available granularities are char level ("char") and word level ("word")
        """
        self.ltp = None
        if granularity == "word":
            self.ltp = LTP(device=torch.device(device) if torch.cuda.is_available() else torch.device("cpu"))
            self.ltp.add_words(words=["[缺失成分]"], max_window=6)
        self.segmented = segmented
        self.granularity = granularity
        if self.granularity == "word":
            self.tokenizer = self.split_word
        elif self.granularity == "char":
            self.tokenizer = functools.partial(self.split_char, bpe=bpe)
        else:
            raise NotImplementedError

    def __repr__(self) -> str:
        return "{:s}\nMode: {:s}\n".format(str(self.__class__.__name__), self.granularity)

    def __call__(self,
                 input_strings: List[str]
                 ) -> List:
        """
        Tokenization entry point.
        :param input_strings: list of strings to tokenize
        :return: list of results; each result is a list of tuples of the form (token, pos_tag, pinyin)
        """
        if not self.segmented:
            input_strings = ["".join(s.split(" ")) for s in input_strings]
        results = self.tokenizer(input_strings)
        return results

    def split_char(self, input_strings: List[str], bpe=False) -> List:
        """
        Character-level splitting.
        :param input_strings: strings to split into characters
        :return: character-level results
        """
        if bpe:
            from . import tokenization
            project_dir = os.path.dirname(os.path.dirname(__file__))
            tokenizer = tokenization.FullTokenizer(vocab_file=os.path.join(project_dir, "data", "chinese_vocab.txt"), do_lower_case=False)
        results = []
        for input_string in input_strings:
            if not self.segmented:  # if not pre-segmented, split into single characters (no special handling of English punctuation, no BPE); otherwise keep the original segmentation
                segment_string = " ".join([char for char in input_string] if not bpe else tokenizer.tokenize(input_string))
            else:
                segment_string = input_string
                # print(segment_string)
            segment_string = segment_string.replace("[ 缺 失 成 分 ]", "[缺失成分]").split(" ")  # treat the missing-component tag as a single token
            results.append([(char, "unk", pinyin(char, style=Style.NORMAL, heteronym=True)[0]) for char in segment_string])
        return results

    def split_word(self, input_strings: List[str]) -> List:
        """
        Word-level segmentation.
        :param input_strings: strings to segment
        :return: segmentation results
        """
        if self.segmented:
            seg, hidden = self.ltp.seg([input_string.split(" ") for input_string in input_strings], is_preseged=True)
        else:
            seg, hidden = self.ltp.seg(input_strings)
        pos = self.ltp.pos(hidden)
        result = []
        for s, p in zip(seg, pos):
            pinyin = [lazy_pinyin(word) for word in s]
            result.append(list(zip(s, p, pinyin)))
        return result


if __name__ == "__main__":
    tokenizer = Tokenizer("word")
    print(tokenizer(["LAC是个优秀的分词工具", "百度是一家高科技公司"]))
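At char granularity the tokenizer skips LTP entirely and simply attaches a placeholder POS tag and a heteronym pinyin list to every character. A small sketch of that path; the printed output is abridged and only indicative:

char_tokenizer = Tokenizer("char")
print(char_tokenizer(["百度是一家高科技公司"]))
# [[('百', 'unk', ['bai', ...]), ('度', 'unk', ['du', ...]), ...]]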
opencompass/datasets/lawbench/utils/parallel_to_m2.py View file @ 4dd9a3fc

import os
from modules.annotator import Annotator
from modules.tokenizer import Tokenizer
import argparse
from collections import Counter
from tqdm import tqdm
import torch
from collections import defaultdict
from multiprocessing import Pool
from opencc import OpenCC
import timeout_decorator

os.environ["TOKENIZERS_PARALLELISM"] = "false"

annotator, sentence_to_tokenized = None, None
cc = OpenCC("t2s")


@timeout_decorator.timeout(10)
def annotate_with_time_out(line):
    """
    :param line:
    :return:
    """
    sent_list = line.split("\t")[1:]
    source = sent_list[0]
    if args.segmented:
        source = source.strip()
    else:
        source = "".join(source.strip().split())
    output_str = ""
    for idx, target in enumerate(sent_list[1:]):
        try:
            if args.segmented:
                target = target.strip()
            else:
                target = "".join(target.strip().split())
            if not args.no_simplified:
                target = cc.convert(target)
            source_tokenized, target_tokenized = sentence_to_tokenized[source], sentence_to_tokenized[target]
            out, cors = annotator(source_tokenized, target_tokenized, idx)
            if idx == 0:
                output_str += "".join(out[:-1])
            else:
                output_str += "".join(out[1:-1])
        except Exception:
            raise Exception
    return output_str


def annotate(line):
    """
    :param line:
    :return:
    """
    sent_list = line.split("\t")[1:]
    source = sent_list[0]
    if args.segmented:
        source = source.strip()
    else:
        source = "".join(source.strip().split())
    output_str = ""
    for idx, target in enumerate(sent_list[1:]):
        try:
            if args.segmented:
                target = target.strip()
            else:
                target = "".join(target.strip().split())
            if not args.no_simplified:
                target = cc.convert(target)
            source_tokenized, target_tokenized = sentence_to_tokenized[source], sentence_to_tokenized[target]
            out, cors = annotator(source_tokenized, target_tokenized, idx)
            if idx == 0:
                output_str += "".join(out[:-1])
            else:
                output_str += "".join(out[1:-1])
        except Exception:
            raise Exception
    return output_str


def firsttime_process(args):
    tokenizer = Tokenizer(args.granularity, args.device, args.segmented, args.bpe)
    global annotator, sentence_to_tokenized
    annotator = Annotator.create_default(args.granularity, args.multi_cheapest_strategy)
    lines = open(args.file, "r", encoding="utf-8").read().strip().split("\n")  # format: id src tgt1 tgt2...
    # error_types = []

    with open(args.output, "w", encoding="utf-8") as f:
        count = 0
        sentence_set = set()
        sentence_to_tokenized = {}
        for line in lines:
            sent_list = line.split("\t")[1:]
            for idx, sent in enumerate(sent_list):
                if args.segmented:
                    # print(sent)
                    sent = sent.strip()
                else:
                    sent = "".join(sent.split()).strip()
                if idx >= 1:
                    if not args.no_simplified:
                        sentence_set.add(cc.convert(sent))
                    else:
                        sentence_set.add(sent)
                else:
                    sentence_set.add(sent)
        batch = []
        for sent in tqdm(sentence_set):
            count += 1
            if sent:
                batch.append(sent)
            if count % args.batch_size == 0:
                results = tokenizer(batch)
                for s, r in zip(batch, results):
                    sentence_to_tokenized[s] = r  # Get tokenization map.
                batch = []
        if batch:
            results = tokenizer(batch)
            for s, r in zip(batch, results):
                sentence_to_tokenized[s] = r  # Get tokenization map.

        timeout_indices = []

        # single-process mode
        for idx, line in enumerate(tqdm(lines)):
            try:
                ret = annotate_with_time_out(line)
            except Exception:
                timeout_indices.append(idx)
    return timeout_indices


def main(args):
    timeout_indices = firsttime_process(args)
    tokenizer = Tokenizer(args.granularity, args.device, args.segmented, args.bpe)
    global annotator, sentence_to_tokenized
    annotator = Annotator.create_default(args.granularity, args.multi_cheapest_strategy)
    lines = open(args.file, "r", encoding="utf-8").read().strip().split("\n")
    new_lines = []  # format: id src tgt1 tgt2...

    with open(args.output, "w", encoding="utf-8") as f:
        count = 0
        sentence_set = set()
        sentence_to_tokenized = {}
        for line_idx, line in enumerate(lines):
            if line_idx in timeout_indices:
                # print(f"line before split: {line}")
                line_split = line.split("\t")
                line_number, sent_list = line_split[0], line_split[1:]
                assert len(sent_list) == 2
                sent_list[-1] = " 无"
                line = line_number + "\t" + "\t".join(sent_list)
                # print(f"line time out: {line}")
                new_lines.append(line)
            else:
                new_lines.append(line)
            sent_list = line.split("\t")[1:]
            for idx, sent in enumerate(sent_list):
                if args.segmented:
                    # print(sent)
                    sent = sent.strip()
                else:
                    sent = "".join(sent.split()).strip()
                if idx >= 1:
                    if not args.no_simplified:
                        sentence_set.add(cc.convert(sent))
                    else:
                        sentence_set.add(sent)
                else:
                    sentence_set.add(sent)
        batch = []
        for sent in tqdm(sentence_set):
            count += 1
            if sent:
                batch.append(sent)
            if count % args.batch_size == 0:
                results = tokenizer(batch)
                for s, r in zip(batch, results):
                    sentence_to_tokenized[s] = r  # Get tokenization map.
                batch = []
        if batch:
            results = tokenizer(batch)
            for s, r in zip(batch, results):
                sentence_to_tokenized[s] = r  # Get tokenization map.

        # single-process mode
        lines = new_lines
        for idx, line in enumerate(tqdm(lines)):
            ret = annotate(line)
            f.write(ret)
            f.write("\n")

        # multi-process mode: only tested on Linux; recommended on a Linux server
        # with Pool(args.worker_num) as pool:
        #     for ret in pool.imap(annotate, tqdm(lines), chunksize=8):
        #         if ret:
        #             f.write(ret)
        #             f.write("\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Choose input file to annotate")
    parser.add_argument("-f", "--file", type=str, required=True, help="Input parallel file")
    parser.add_argument("-o", "--output", type=str, help="Output file", required=True)
    parser.add_argument("-b", "--batch_size", type=int, help="The size of batch", default=128)
    parser.add_argument("-d", "--device", type=int, help="The ID of GPU", default=0)
    parser.add_argument("-w", "--worker_num", type=int, help="The number of workers", default=16)
    parser.add_argument("-g", "--granularity", type=str, help="Choose char-level or word-level evaluation", default="char")
    parser.add_argument("-m", "--merge", help="Whether merge continuous replacement/deletion/insertion", action="store_true")
    parser.add_argument("-s", "--multi_cheapest_strategy", type=str, choices=["first", "all"], default="all")
    parser.add_argument("--segmented", help="Whether tokens have been segmented", action="store_true")  # supports pre-tokenized input separated by spaces
    parser.add_argument("--no_simplified", help="Whether simplifying chinese", action="store_true")  # convert all corrections to Simplified Chinese
    parser.add_argument("--bpe", help="Whether to use bpe", action="store_true")  # supports BPE splitting of English words
    args = parser.parse_args()
    main(args)
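One possible way to drive this converter from another script, assuming its dependencies are installed and the input file follows the id<TAB>source<TAB>target layout expected above; the file names are hypothetical, and the flags map onto the argparse options defined in __main__:

import subprocess

subprocess.run([
    "python", "parallel_to_m2.py",
    "-f", "demo.para",   # hypothetical parallel input file
    "-o", "demo.m2",     # M2 output
    "-g", "word",        # word-level error types (the default is "char")
    "-s", "all",         # keep all cheapest alignments
], check=True)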
opencompass/runners/dlc.py View file @ 4dd9a3fc

+import datetime
 import os
 import os.path as osp
 import random
+import re
 import subprocess
+import sys
 import time
 from functools import partial
-from typing import Any, Dict, List, Tuple
+from typing import Any, Dict, List, Optional, Tuple

 import mmengine
 from mmengine.config import ConfigDict
...
@@ -43,6 +46,11 @@ class DLCRunner(BaseRunner):
         self.max_num_workers = max_num_workers
         self.retry = retry

+        logger = get_logger()
+        logger.warning(
+            'To ensure the integrity of the log results, the log displayed '
+            f'by {self.__class__.__name__} has a 10-second delay.')
+
     def launch(self, tasks: List[Dict[str, Any]]) -> List[Tuple[str, int]]:
         """Launch multiple tasks.
...
@@ -63,18 +71,23 @@ class DLCRunner(BaseRunner):
             status = [self._launch(task, random_sleep=False) for task in tasks]
         return status

-    def _launch(self, cfg: ConfigDict, random_sleep: bool = True):
+    def _launch(self, cfg: ConfigDict, random_sleep: Optional[bool] = None):
         """Launch a single task.

         Args:
             cfg (ConfigDict): Task config.
             random_sleep (bool): Whether to sleep for a random time before
-                running the command. This avoids cluster error when launching
-                multiple tasks at the same time. Default: True.
+                running the command. When Aliyun has many tasks to schedule,
+                its stability decreases. Therefore, when we need to submit a
+                large number of tasks at once, we adopt the "random_sleep"
+                strategy. Tasks that would have been submitted all at once are
+                now evenly spread out over a 10-second period. Default: None.

         Returns:
             tuple[str, int]: Task name and exit code.
         """
+        if random_sleep is None:
+            random_sleep = (self.max_num_workers > 32)
+
         task = TASKS.build(dict(cfg=cfg, type=self.task_cfg['type']))
         num_gpus = task.num_gpus
...
@@ -116,7 +129,7 @@ class DLCRunner(BaseRunner):
         # Run command with retry
         if self.debug:
-            stdout = None
+            stdout = sys.stdout
         else:
             out_path = task.get_log_path(file_extension='out')
             mmengine.mkdir_or_exist(osp.split(out_path)[0])
...
@@ -124,30 +137,92 @@ class DLCRunner(BaseRunner):
         if random_sleep:
             time.sleep(random.randint(0, 10))
-        result = subprocess.run(cmd,
-                                shell=True,
-                                text=True,
-                                stdout=stdout,
-                                stderr=stdout)
+
+        def _run_within_retry():
+            try:
+                process = subprocess.Popen(cmd,
+                                           shell=True,
+                                           text=True,
+                                           stdout=subprocess.PIPE,
+                                           stderr=subprocess.PIPE)
+                job_id = None
+                job_allocated = False
+                job_finished = False
+                last_end_time = datetime.datetime.now().strftime(
+                    '%Y-%m-%dT%H:%M:%SZ')
+                while True:
+                    if not job_allocated:
+                        line = process.stdout.readline()
+                        if not line:
+                            break
+                        match = re.search(r'(dlc[0-9a-z]+)', line)
+                        if match and job_id is None:
+                            job_id = match.group(1)
+                        stdout.write(line)
+                        match = re.search(r'Job .* is \[Running\]', line)
+                        if match:
+                            job_allocated = True
+                    else:
+                        try:
+                            process.wait(10)
+                        except subprocess.TimeoutExpired:
+                            pass
+                        else:
+                            job_finished = True
+                        if job_finished:
+                            this_end_time = datetime.datetime.now(
+                            ).strftime('%Y-%m-%dT%H:%M:%SZ')
+                        else:
+                            this_end_time = (
+                                datetime.datetime.now() -
+                                datetime.timedelta(seconds=10)
+                            ).strftime('%Y-%m-%dT%H:%M:%SZ')
+                        logs_cmd = ('dlc logs'
+                                    f' {job_id} {job_id}-worker-0'
+                                    f' --start_time {last_end_time}'
+                                    f' --end_time {this_end_time}'
+                                    f" -c {self.aliyun_cfg['dlc_config_path']}")
+                        log_process = subprocess.Popen(
+                            logs_cmd,
+                            shell=True,
+                            text=True,
+                            stdout=subprocess.PIPE,
+                            stderr=subprocess.PIPE)
+                        log_output, log_err = log_process.communicate()
+                        log_output = '\n'.join(log_output.split('\n')[2:])
+                        stdout.write(log_output)
+                        last_end_time = this_end_time
+                    stdout.flush()
+                    if job_finished:
+                        break
+                process.wait()
+                return process.returncode
+            finally:
+                if job_id is not None:
+                    cancel_cmd = ('dlc stop job'
+                                  f' {job_id}'
+                                  f" -c {self.aliyun_cfg['dlc_config_path']}"
+                                  ' -f')
+                    subprocess.run(cancel_cmd,
+                                   shell=True,
+                                   text=True,
+                                   stdout=subprocess.PIPE,
+                                   stderr=subprocess.PIPE)
+
+        return_code = _run_within_retry()
         retry = self.retry
         output_paths = task.get_output_paths()
-        while self._job_failed(result.returncode,
+        while self._job_failed(return_code,
                                output_paths) and retry > 0:
             retry -= 1
+            if random_sleep:
+                time.sleep(random.randint(0, 10))
+            # Re-generate command to refresh ports.
             cmd = get_cmd()
-            result = subprocess.run(cmd,
-                                    shell=True,
-                                    text=True,
-                                    stdout=stdout,
-                                    stderr=stdout)
+            return_code = _run_within_retry()
         finally:
             # Clean up
             os.remove(param_file)
-        return task_name, result.returncode
+        return task_name, return_code

     def _job_failed(self, return_code: int, output_paths: List[str]) -> bool:
         return return_code != 0 or not all(
...
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment