Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
opencompass
Commits
861942ab
Unverified
Commit
861942ab
authored
Oct 13, 2023
by
Leymore
Committed by
GitHub
Oct 13, 2023
Browse files
[Feature] Add lawbench (#460)
* add lawbench * update requirements * update
parent
fbf5089c
Changes
40
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2931 additions
and
0 deletions
+2931
-0
opencompass/datasets/lawbench/evaluation_functions/wbfl.py
opencompass/datasets/lawbench/evaluation_functions/wbfl.py
+42
-0
opencompass/datasets/lawbench/evaluation_functions/wsjd.py
opencompass/datasets/lawbench/evaluation_functions/wsjd.py
+50
-0
opencompass/datasets/lawbench/evaluation_functions/xxcq.py
opencompass/datasets/lawbench/evaluation_functions/xxcq.py
+17
-0
opencompass/datasets/lawbench/evaluation_functions/ydlj.py
opencompass/datasets/lawbench/evaluation_functions/ydlj.py
+17
-0
opencompass/datasets/lawbench/evaluation_functions/yqzy.py
opencompass/datasets/lawbench/evaluation_functions/yqzy.py
+18
-0
opencompass/datasets/lawbench/evaluation_functions/zxfl.py
opencompass/datasets/lawbench/evaluation_functions/zxfl.py
+27
-0
opencompass/datasets/lawbench/lawbench.py
opencompass/datasets/lawbench/lawbench.py
+83
-0
opencompass/datasets/lawbench/utils/char_smi.py
opencompass/datasets/lawbench/utils/char_smi.py
+456
-0
opencompass/datasets/lawbench/utils/compare_m2_for_evaluation.py
...pass/datasets/lawbench/utils/compare_m2_for_evaluation.py
+433
-0
opencompass/datasets/lawbench/utils/comprehension_scores.py
opencompass/datasets/lawbench/utils/comprehension_scores.py
+82
-0
opencompass/datasets/lawbench/utils/function_utils.py
opencompass/datasets/lawbench/utils/function_utils.py
+49
-0
opencompass/datasets/lawbench/utils/modules/alignment.py
opencompass/datasets/lawbench/utils/modules/alignment.py
+334
-0
opencompass/datasets/lawbench/utils/modules/annotator.py
opencompass/datasets/lawbench/utils/modules/annotator.py
+76
-0
opencompass/datasets/lawbench/utils/modules/classifier.py
opencompass/datasets/lawbench/utils/modules/classifier.py
+151
-0
opencompass/datasets/lawbench/utils/modules/merger.py
opencompass/datasets/lawbench/utils/modules/merger.py
+273
-0
opencompass/datasets/lawbench/utils/modules/tokenization.py
opencompass/datasets/lawbench/utils/modules/tokenization.py
+346
-0
opencompass/datasets/lawbench/utils/modules/tokenizer.py
opencompass/datasets/lawbench/utils/modules/tokenizer.py
+92
-0
opencompass/datasets/lawbench/utils/parallel_to_m2.py
opencompass/datasets/lawbench/utils/parallel_to_m2.py
+221
-0
opencompass/datasets/lawbench/utils/rc_f1.py
opencompass/datasets/lawbench/utils/rc_f1.py
+158
-0
requirements/runtime.txt
requirements/runtime.txt
+6
-0
No files found.
opencompass/datasets/lawbench/evaluation_functions/wbfl.py
0 → 100644
View file @
861942ab
"""
task: multiple choice classification
metric: F1 score
婚姻文本分类
"""
def
compute_wbfl
(
data_dict
):
"""
A reference (R) contains a list of options, each option is from the option_list.
We will extract the options appearing in the prediction and convert them into a set (P).
We compute the F1 score between the prediction (P) and the reference (R).
"""
score_list
,
abstentions
=
[],
0
option_list
=
[
"婚后有子女"
,
"限制行为能力子女抚养"
,
"有夫妻共同财产"
,
"支付抚养费"
,
"不动产分割"
,
"婚后分局"
,
"二次起诉离婚"
,
"按月给付抚养费"
,
"准予离婚"
,
"有夫妻共同债务"
,
"婚前个人财产"
,
"法定离婚"
,
"不履行家庭义务"
,
"存在非婚生子"
,
"适当帮助"
,
"不履行离婚协议"
,
"损害赔偿"
,
"感情不和分居满二年"
,
"子女随非抚养权人生活"
,
"婚后个人财产"
]
for
example
in
data_dict
:
question
,
prediction
,
answer
=
example
[
"origin_prompt"
],
example
[
"prediction"
],
example
[
"refr"
]
assert
answer
.
startswith
(
"类别:"
)
and
answer
.
endswith
(
"。"
),
f
"answer:
{
answer
}
, question:
{
question
}
"
gt_list
=
(
answer
[
3
:
-
1
].
split
(
"、"
))
for
gt
in
gt_list
:
assert
gt
in
option_list
,
f
"gt:
{
gt
}
, question:
{
question
}
"
gt_set
=
set
(
gt_list
)
prediction_list
=
[]
for
option
in
option_list
:
if
option
in
prediction
:
prediction_list
.
append
(
option
)
if
len
(
prediction_list
)
==
0
:
abstentions
+=
1
predict_set
=
set
(
prediction_list
)
precision
=
len
(
gt_set
.
intersection
(
predict_set
))
/
len
(
predict_set
)
if
len
(
predict_set
)
!=
0
else
0
recall
=
len
(
gt_set
.
intersection
(
predict_set
))
/
len
(
gt_set
)
if
len
(
gt_set
)
!=
0
else
0
f1_score
=
2
*
precision
*
recall
/
(
precision
+
recall
)
if
(
precision
+
recall
)
!=
0
else
0
score_list
.
append
(
f1_score
)
# compute the accuracy of score_list
final_f1_score
=
sum
(
score_list
)
/
len
(
score_list
)
return
{
'score'
:
final_f1_score
,
'abstention_rate'
:
abstentions
/
len
(
data_dict
)}
opencompass/datasets/lawbench/evaluation_functions/wsjd.py
0 → 100644
View file @
861942ab
import
re
import
os
import
subprocess
"""
Task: legal document grammar correction
Metric: F0.5 score
文书校对
"""
def
compute_wsjd
(
data_dict
):
origins
,
references
,
predictions
=
[],
[],
[]
for
example
in
data_dict
:
question
,
prediction
,
answer
=
example
[
"origin_prompt"
],
example
[
"prediction"
],
example
[
"refr"
]
if
isinstance
(
question
,
list
):
question
=
question
[
0
][
'prompt'
]
start
=
question
.
index
(
'句子:
\n
'
)
+
4
origins
.
append
(
re
.
sub
(
r
'\n|\t'
,
''
,
question
[
start
:].
split
(
'
\n
'
)[
0
]))
# truncate predictions >5 tokens longer than the reference
prediction
=
re
.
sub
(
r
'\n|\t'
,
''
,
prediction
)
if
len
(
prediction
)
-
len
(
answer
)
>
5
:
prediction
=
prediction
[:
len
(
answer
)
+
5
]
if
len
(
prediction
)
==
0
:
prediction
=
"无内容"
predictions
.
append
(
prediction
)
references
.
append
(
re
.
sub
(
r
'\n|\t'
,
''
,
answer
))
#generate input files for ChERRANT
preds
=
[
f
'
{
i
}
\t
{
origin
}
\t
{
prediction
}
\n
'
for
i
,
(
origin
,
prediction
)
in
enumerate
(
zip
(
origins
,
predictions
))]
golds
=
[
f
'
{
i
}
\t
{
origin
}
\t
{
reference
}
\n
'
for
i
,
(
origin
,
reference
)
in
enumerate
(
zip
(
origins
,
references
))]
now_path
=
os
.
path
.
abspath
(
os
.
getcwd
())
utils_path
=
os
.
path
.
abspath
(
os
.
path
.
join
(
__file__
,
'..'
,
'..'
,
'utils'
))
uid
=
os
.
getuid
()
os
.
chdir
(
utils_path
)
with
open
(
f
'/tmp/tmp_pred_
{
uid
}
.para'
,
'w'
)
as
f
:
f
.
writelines
(
preds
)
with
open
(
f
'/tmp/tmp_gold_
{
uid
}
.para'
,
'w'
)
as
f
:
f
.
writelines
(
golds
)
os
.
environ
[
'KMP_DUPLICATE_LIB_OK'
]
=
'True'
os
.
system
(
f
'python3 parallel_to_m2.py -f /tmp/tmp_pred_
{
uid
}
.para -o /tmp/tmp_pred_
{
uid
}
.para.m2 -g char'
)
os
.
system
(
f
'python3 parallel_to_m2.py -f /tmp/tmp_gold_
{
uid
}
.para -o /tmp/tmp_gold_
{
uid
}
.para.m2 -g char'
)
output
=
subprocess
.
check_output
(
f
"python3 compare_m2_for_evaluation.py -hyp /tmp/tmp_pred_
{
uid
}
.para.m2 -ref /tmp/tmp_gold_
{
uid
}
.para.m2"
,
shell
=
True
)
score
=
float
(
output
.
decode
().
split
(
'
\t
'
)[
-
1
].
split
(
'
\n
'
)[
0
])
#remove prediction files
os
.
remove
(
f
'/tmp/tmp_pred_
{
uid
}
.para'
)
os
.
remove
(
f
'/tmp/tmp_gold_
{
uid
}
.para'
)
os
.
remove
(
f
'/tmp/tmp_pred_
{
uid
}
.para.m2'
)
os
.
remove
(
f
'/tmp/tmp_gold_
{
uid
}
.para.m2'
)
os
.
chdir
(
now_path
)
return
{
"score"
:
score
}
opencompass/datasets/lawbench/evaluation_functions/xxcq.py
0 → 100644
View file @
861942ab
from
..utils.comprehension_scores
import
compute_ie_f1
"""
task: information extraction
metric: F1 score
信息抽取
"""
def
compute_xxcq
(
data_dict
):
references
,
predictions
=
[],
[]
for
example
in
data_dict
:
question
,
prediction
,
answer
=
example
[
"origin_prompt"
],
example
[
"prediction"
],
example
[
"refr"
]
predictions
.
append
(
prediction
)
references
.
append
(
answer
)
return
compute_ie_f1
(
predictions
,
references
,
{
"犯罪嫌疑人"
,
"受害人"
,
"被盗货币"
,
"物品价值"
,
"盗窃获利"
,
"被盗物品"
,
"作案工具"
,
"时间"
,
"地点"
,
"组织机构"
})
opencompass/datasets/lawbench/evaluation_functions/ydlj.py
0 → 100644
View file @
861942ab
from
..utils.comprehension_scores
import
compute_rc_f1
"""
Task: machine reading comprehension
Metric: F1 score
法律阅读理解
"""
def
compute_ydlj
(
data_dict
):
references
,
predictions
=
[],
[]
for
example
in
data_dict
:
question
,
prediction
,
answer
=
example
[
"origin_prompt"
],
example
[
"prediction"
],
example
[
"refr"
]
answer
=
answer
.
replace
(
"回答:"
,
""
)
predictions
.
append
(
prediction
)
references
.
append
(
answer
)
f1_score
=
compute_rc_f1
(
predictions
,
references
)
return
f1_score
opencompass/datasets/lawbench/evaluation_functions/yqzy.py
0 → 100644
View file @
861942ab
from
..utils.function_utils
import
compute_rouge
#舆情摘要
def
compute_yqzy
(
data_dict
):
"""
Compute the ROUGE-L score between the prediction and the reference
"""
references
,
predictions
=
[],
[]
for
example
in
data_dict
:
question
,
prediction
,
answer
=
example
[
"origin_prompt"
],
example
[
"prediction"
],
example
[
"refr"
]
predictions
.
append
(
prediction
)
references
.
append
(
answer
)
# compute the accuracy of score_list
rouge_scores
=
compute_rouge
(
predictions
,
references
)
rouge_ls
=
[
score
[
"rouge-l"
][
"f"
]
for
score
in
rouge_scores
]
average_rouge_l
=
sum
(
rouge_ls
)
/
len
(
rouge_ls
)
return
{
"score"
:
average_rouge_l
}
opencompass/datasets/lawbench/evaluation_functions/zxfl.py
0 → 100644
View file @
861942ab
from
..utils.function_utils
import
multi_choice_judge
"""
task: multiple choice classification
metric: accuracy
咨询分类
"""
def
compute_zxfl
(
data_dict
):
"""
A reference (R) contains a list of options, each option is from the option_list.
We will extract the options appearing in the prediction and convert them into a set (P).
We compute the accuracy between the prediction (P) and the reference (R).
"""
score_list
,
abstentions
=
[],
0
option_list
=
[
'婚姻家庭'
,
'劳动纠纷'
,
'交通事故'
,
'债权债务'
,
'刑事辩护'
,
'合同纠纷'
,
'房产纠纷'
,
'侵权'
,
'公司法'
,
'医疗纠纷'
,
'拆迁安置'
,
'行政诉讼'
,
'建设工程'
,
'知识产权'
,
'综合咨询'
,
'人身损害'
,
'涉外法律'
,
'海事海商'
,
'消费权益'
,
'抵押担保'
]
for
example
in
data_dict
:
question
,
prediction
,
answer
=
example
[
"origin_prompt"
],
example
[
"prediction"
],
example
[
"refr"
]
judge
=
multi_choice_judge
(
prediction
,
option_list
,
answer
)
score_list
.
append
(
judge
[
"score"
])
abstentions
+=
judge
[
"abstention"
]
# compute the accuracy of score_list
final_accuracy_score
=
sum
(
score_list
)
/
len
(
score_list
)
return
{
'score'
:
final_accuracy_score
,
'abstention_rate'
:
abstentions
/
len
(
data_dict
)}
opencompass/datasets/lawbench/lawbench.py
0 → 100644
View file @
861942ab
import
json
import
os
from
datasets
import
Dataset
from
opencompass.openicl.icl_evaluator
import
BaseEvaluator
from
opencompass.registry
import
ICL_EVALUATORS
,
LOAD_DATASET
from
..base
import
BaseDataset
from
.evaluation_functions
import
(
cjft
,
flzx
,
ftcs
,
jdzy
,
jec_ac
,
jec_kd
,
jetq
,
lblj
,
ljp_accusation
,
ljp_article
,
ljp_imprison
,
sjjc
,
wbfl
,
wsjd
,
xxcq
,
ydlj
,
yqzy
,
zxfl
)
@
LOAD_DATASET
.
register_module
()
class
LawBenchDataset
(
BaseDataset
):
@
staticmethod
def
load
(
path
:
str
,
index
:
str
)
->
Dataset
:
path
=
os
.
path
.
join
(
path
,
index
+
'.json'
)
with
open
(
path
,
'r'
)
as
f
:
data
=
json
.
load
(
f
)
return
Dataset
.
from_list
(
data
)
funct_dict
=
{
'1-1'
:
ftcs
.
compute_ftcs
,
'1-2'
:
jec_kd
.
compute_jec_kd
,
'2-1'
:
wsjd
.
compute_wsjd
,
'2-2'
:
jdzy
.
compute_jdzy
,
'2-3'
:
wbfl
.
compute_wbfl
,
'2-4'
:
zxfl
.
compute_zxfl
,
'2-5'
:
ydlj
.
compute_ydlj
,
'2-6'
:
xxcq
.
compute_xxcq
,
'2-7'
:
yqzy
.
compute_yqzy
,
'2-8'
:
lblj
.
compute_lblj
,
'2-9'
:
sjjc
.
compute_sjjc
,
'2-10'
:
sjjc
.
compute_cfcy
,
'3-1'
:
ljp_article
.
compute_ljp_article
,
'3-2'
:
cjft
.
compute_cjft
,
'3-3'
:
ljp_accusation
.
compute_ljp_accusation
,
'3-4'
:
ljp_imprison
.
compute_ljp_imprison
,
'3-5'
:
ljp_imprison
.
compute_ljp_imprison
,
'3-6'
:
jec_ac
.
compute_jec_ac
,
'3-7'
:
jetq
.
compute_jetq
,
'3-8'
:
flzx
.
compute_flzx
,
}
class
LawBenchEvaluator
(
BaseEvaluator
):
def
__init__
(
self
,
index
)
->
None
:
super
().
__init__
()
self
.
index
=
index
def
score
(
self
,
predictions
,
references
,
origin_prompt
):
if
len
(
predictions
)
!=
len
(
references
):
return
{
'error'
:
'predictions and references have different '
'length'
}
data_dict
=
[{
'origin_prompt'
:
origin_prompt
[
i
],
'prediction'
:
predictions
[
i
],
'refr'
:
references
[
i
],
}
for
i
in
range
(
len
(
predictions
))]
scores
=
funct_dict
[
self
.
index
](
data_dict
)
scores
=
{
k
:
v
*
100
for
k
,
v
in
scores
.
items
()}
return
scores
for
index
in
funct_dict
:
# fix classic closure problem
def
_register
(
index
):
ICL_EVALUATORS
.
register_module
(
name
=
'LawBenchEvaluator_'
+
index
.
replace
(
'-'
,
'_'
),
module
=
lambda
*
args
,
**
kwargs
:
LawBenchEvaluator
(
index
=
index
,
*
args
,
**
kwargs
))
_register
(
index
)
opencompass/datasets/lawbench/utils/char_smi.py
0 → 100644
View file @
861942ab
### Copy from https://github.com/iqiyi/FASPell ###
"""
Requirements:
- java (required only if tree edit distance is used)
- numpy
"""
import
numpy
as
np
from
subprocess
import
Popen
,
PIPE
,
STDOUT
import
os
import
argparse
IDCS
=
{
'
\u2ff0
'
:
2
,
# 12 ideographic description characters and their capacity of son nodes
'
\u2ff1
'
:
2
,
'
\u2ff2
'
:
3
,
'
\u2ff3
'
:
3
,
'
\u2ff4
'
:
2
,
'
\u2ff5
'
:
2
,
'
\u2ff6
'
:
2
,
'
\u2ff7
'
:
2
,
'
\u2ff8
'
:
2
,
'
\u2ff9
'
:
2
,
'
\u2ffa
'
:
2
,
'
\u2ffb
'
:
2
,
}
PINYIN
=
{
'ā'
:
[
'a'
,
1
],
'á'
:
[
'a'
,
2
],
'ǎ'
:
[
'a'
,
3
],
'à'
:
[
'a'
,
4
],
'ē'
:
[
'e'
,
1
],
'é'
:
[
'e'
,
2
],
'ě'
:
[
'e'
,
3
],
'è'
:
[
'e'
,
4
],
'ī'
:
[
'i'
,
1
],
'í'
:
[
'i'
,
2
],
'ǐ'
:
[
'i'
,
3
],
'ì'
:
[
'i'
,
4
],
'ō'
:
[
'o'
,
1
],
'ó'
:
[
'o'
,
2
],
'ǒ'
:
[
'o'
,
3
],
'ò'
:
[
'o'
,
4
],
'ū'
:
[
'u'
,
1
],
'ú'
:
[
'u'
,
2
],
'ǔ'
:
[
'u'
,
3
],
'ù'
:
[
'u'
,
4
],
'ǖ'
:
[
'ü'
,
1
],
'ǘ'
:
[
'ü'
,
2
],
'ǚ'
:
[
'ü'
,
3
],
'ǜ'
:
[
'ü'
,
4
],
''
:
[
'm'
,
2
],
'ń'
:
[
'n'
,
2
],
'ň'
:
[
'n'
,
3
],
'ǹ'
:
[
'n'
,
4
],
}
# APTED_JAR_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'apted.jar')
APTED_JAR_PATH
=
'apted.jar'
def
tree_edit_distance
(
tree_a
,
tree_b
):
"""
We use APTED algorithm proposed by M. Pawlik and N. Augsten
github link: https://github.com/DatabaseGroup/apted
"""
p
=
Popen
([
'java'
,
'-jar'
,
APTED_JAR_PATH
,
'-t'
,
tree_a
,
tree_b
],
stdout
=
PIPE
,
stderr
=
STDOUT
)
res
=
[
line
for
line
in
p
.
stdout
]
res
=
res
[
0
]
res
=
res
.
strip
()
res
=
float
(
res
)
return
res
def
edit_distance
(
string_a
,
string_b
,
name
=
'Levenshtein'
):
"""
>>> edit_distance('abcde', 'avbcude')
2
>>> edit_distance(['至', '刂'], ['亻', '至', '刂'])
1
>>> edit_distance('fang', 'qwe')
4
>>> edit_distance('fang', 'hen')
3
"""
size_x
=
len
(
string_a
)
+
1
size_y
=
len
(
string_b
)
+
1
matrix
=
np
.
zeros
((
size_x
,
size_y
),
dtype
=
int
)
for
x
in
range
(
size_x
):
matrix
[
x
,
0
]
=
x
for
y
in
range
(
size_y
):
matrix
[
0
,
y
]
=
y
for
x
in
range
(
1
,
size_x
):
for
y
in
range
(
1
,
size_y
):
if
string_a
[
x
-
1
]
==
string_b
[
y
-
1
]:
matrix
[
x
,
y
]
=
min
(
matrix
[
x
-
1
,
y
]
+
1
,
matrix
[
x
-
1
,
y
-
1
],
matrix
[
x
,
y
-
1
]
+
1
)
else
:
if
name
==
'Levenshtein'
:
matrix
[
x
,
y
]
=
min
(
matrix
[
x
-
1
,
y
]
+
1
,
matrix
[
x
-
1
,
y
-
1
]
+
1
,
matrix
[
x
,
y
-
1
]
+
1
)
else
:
# Canonical
matrix
[
x
,
y
]
=
min
(
matrix
[
x
-
1
,
y
]
+
1
,
matrix
[
x
-
1
,
y
-
1
]
+
2
,
matrix
[
x
,
y
-
1
]
+
1
)
return
matrix
[
size_x
-
1
,
size_y
-
1
]
class
CharFuncs
(
object
):
def
__init__
(
self
,
char_meta_fname
):
self
.
data
=
self
.
load_char_meta
(
char_meta_fname
)
self
.
char_dict
=
dict
([(
c
,
0
)
for
c
in
self
.
data
])
self
.
safe
=
{
'
\u2ff0
'
:
'A'
,
# to eliminate the bug that, in Windows CMD, char ⿻ and ⿵ are encoded to be the same.
'
\u2ff1
'
:
'B'
,
'
\u2ff2
'
:
'C'
,
'
\u2ff3
'
:
'D'
,
'
\u2ff4
'
:
'E'
,
'
\u2ff5
'
:
'F'
,
'
\u2ff6
'
:
'G'
,
'
\u2ff7
'
:
'H'
,
'
\u2ff8
'
:
'I'
,
'
\u2ff9
'
:
'J'
,
'
\u2ffa
'
:
'L'
,
'
\u2ffb
'
:
'M'
,
}
@
staticmethod
def
load_char_meta
(
fname
):
data
=
{}
f
=
open
(
fname
,
'r'
,
encoding
=
'utf-8'
)
for
line
in
f
:
items
=
line
.
strip
().
split
(
'
\t
'
)
code_point
=
items
[
0
]
char
=
items
[
1
]
pronunciation
=
items
[
2
]
decompositions
=
items
[
3
:]
assert
char
not
in
data
data
[
char
]
=
{
"code_point"
:
code_point
,
"pronunciation"
:
pronunciation
,
"decompositions"
:
decompositions
}
return
data
def
shape_distance
(
self
,
char1
,
char2
,
safe
=
True
,
as_tree
=
False
):
"""
>>> c = CharFuncs('data/char_meta.txt')
>>> c.shape_distance('田', '由')
1
>>> c.shape_distance('牛', '午')
1
"""
assert
char1
in
self
.
data
assert
char2
in
self
.
data
def
safe_encode
(
decomp
):
tree
=
''
for
c
in
string_to_tree
(
decomp
):
if
c
not
in
self
.
safe
:
tree
+=
c
else
:
tree
+=
self
.
safe
[
c
]
return
tree
def
safe_encode_string
(
decomp
):
tree
=
''
for
c
in
decomp
:
if
c
not
in
self
.
safe
:
tree
+=
c
else
:
tree
+=
self
.
safe
[
c
]
return
tree
decomps_1
=
self
.
data
[
char1
][
"decompositions"
]
decomps_2
=
self
.
data
[
char2
][
"decompositions"
]
distance
=
1e5
if
as_tree
:
for
decomp1
in
decomps_1
:
for
decomp2
in
decomps_2
:
if
not
safe
:
ted
=
tree_edit_distance
(
string_to_tree
(
decomp1
),
string_to_tree
(
decomp2
))
else
:
ted
=
tree_edit_distance
(
safe_encode
(
decomp1
),
safe_encode
(
decomp2
))
distance
=
min
(
distance
,
ted
)
else
:
for
decomp1
in
decomps_1
:
for
decomp2
in
decomps_2
:
if
not
safe
:
ed
=
edit_distance
(
decomp1
,
decomp2
)
else
:
ed
=
edit_distance
(
safe_encode_string
(
decomp1
),
safe_encode_string
(
decomp2
))
distance
=
min
(
distance
,
ed
)
return
distance
def
pronunciation_distance
(
self
,
char1
,
char2
):
"""
>>> c = CharFuncs('data/char_meta.txt')
>>> c.pronunciation_distance('田', '由')
3.4
>>> c.pronunciation_distance('牛', '午')
2.6
"""
assert
char1
in
self
.
data
assert
char2
in
self
.
data
pronunciations1
=
self
.
data
[
char1
][
"pronunciation"
]
pronunciations2
=
self
.
data
[
char2
][
"pronunciation"
]
if
pronunciations1
[
0
]
==
'null'
or
pronunciations2
==
'null'
:
return
0.0
else
:
pronunciations1
=
pronunciations1
.
split
(
';'
)
# separate by lan
pronunciations2
=
pronunciations2
.
split
(
';'
)
# separate by lan
distance
=
0.0
count
=
0
for
pron_lan1
,
pron_lan2
in
zip
(
pronunciations1
,
pronunciations2
):
if
(
pron_lan1
==
'null'
)
or
(
pron_lan2
==
'null'
):
pass
else
:
distance_lan
=
1e5
for
p1
in
pron_lan1
.
split
(
','
):
for
p2
in
pron_lan2
.
split
(
','
):
distance_lan
=
min
(
distance_lan
,
edit_distance
(
p1
,
p2
))
distance
+=
distance_lan
count
+=
1
return
distance
/
count
@
staticmethod
def
load_dict
(
fname
):
data
=
{}
f
=
open
(
fname
,
'r'
,
encoding
=
'utf-8'
)
for
line
in
f
:
char
,
freq
=
line
.
strip
().
split
(
'
\t
'
)
assert
char
not
in
data
data
[
char
]
=
freq
return
data
def
similarity
(
self
,
char1
,
char2
,
weights
=
(
0.8
,
0.2
,
0.0
),
as_tree
=
False
):
"""
this function returns weighted similarity. When used in FASPell, each weight can only be 0 or 1.
"""
# assert char1 in self.char_dict
# assert char2 in self.char_dict
shape_w
,
sound_w
,
freq_w
=
weights
if
char1
in
self
.
char_dict
and
char2
in
self
.
char_dict
:
shape_sim
=
self
.
shape_similarity
(
char1
,
char2
,
as_tree
=
as_tree
)
sound_sim
=
self
.
pronunciation_similarity
(
char1
,
char2
)
freq_sim
=
1.0
-
self
.
char_dict
[
char2
]
/
len
(
self
.
char_dict
)
return
shape_sim
*
shape_w
+
sound_sim
*
sound_w
+
freq_sim
*
freq_w
else
:
return
0.0
def
shape_similarity
(
self
,
char1
,
char2
,
safe
=
True
,
as_tree
=
False
):
"""
>>> c = CharFuncs('data/char_meta.txt')
>>> c.shape_similarity('牛', '午')
0.8571428571428572
>>> c.shape_similarity('田', '由')
0.8888888888888888
"""
assert
char1
in
self
.
data
assert
char2
in
self
.
data
def
safe_encode
(
decomp
):
tree
=
''
for
c
in
string_to_tree
(
decomp
):
if
c
not
in
self
.
safe
:
tree
+=
c
else
:
tree
+=
self
.
safe
[
c
]
return
tree
def
safe_encode_string
(
decomp
):
tree
=
''
for
c
in
decomp
:
if
c
not
in
self
.
safe
:
tree
+=
c
else
:
tree
+=
self
.
safe
[
c
]
return
tree
decomps_1
=
self
.
data
[
char1
][
"decompositions"
]
decomps_2
=
self
.
data
[
char2
][
"decompositions"
]
similarity
=
0.0
if
as_tree
:
for
decomp1
in
decomps_1
:
for
decomp2
in
decomps_2
:
if
not
safe
:
ted
=
tree_edit_distance
(
string_to_tree
(
decomp1
),
string_to_tree
(
decomp2
))
else
:
ted
=
tree_edit_distance
(
safe_encode
(
decomp1
),
safe_encode
(
decomp2
))
normalized_ted
=
2
*
ted
/
(
len
(
decomp1
)
+
len
(
decomp2
)
+
ted
)
similarity
=
max
(
similarity
,
1
-
normalized_ted
)
else
:
for
decomp1
in
decomps_1
:
for
decomp2
in
decomps_2
:
if
not
safe
:
ed
=
edit_distance
(
decomp1
,
decomp2
)
else
:
ed
=
edit_distance
(
safe_encode_string
(
decomp1
),
safe_encode_string
(
decomp2
))
normalized_ed
=
ed
/
max
(
len
(
decomp1
),
len
(
decomp2
))
similarity
=
max
(
similarity
,
1
-
normalized_ed
)
return
similarity
def
pronunciation_similarity
(
self
,
char1
,
char2
):
"""
>>> c = CharFuncs('data/char_meta.txt')
>>> c.pronunciation_similarity('牛', '午')
0.27999999999999997
>>> c.pronunciation_similarity('由', '田')
0.09
"""
assert
char1
in
self
.
data
assert
char2
in
self
.
data
pronunciations1
=
self
.
data
[
char1
][
"pronunciation"
]
pronunciations2
=
self
.
data
[
char2
][
"pronunciation"
]
if
pronunciations1
[
0
]
==
'null'
or
pronunciations2
==
'null'
:
return
0.0
else
:
pronunciations1
=
pronunciations1
.
split
(
';'
)
# separate by lan
pronunciations2
=
pronunciations2
.
split
(
';'
)
# separate by lan
similarity
=
0.0
count
=
0
for
pron_lan1
,
pron_lan2
in
zip
(
pronunciations1
,
pronunciations2
):
if
(
pron_lan1
==
'null'
)
or
(
pron_lan2
==
'null'
):
pass
else
:
similarity_lan
=
0.0
for
p1
in
pron_lan1
.
split
(
','
):
for
p2
in
pron_lan2
.
split
(
','
):
tmp_sim
=
1
-
edit_distance
(
p1
,
p2
)
/
max
(
len
(
p1
),
len
(
p2
))
similarity_lan
=
max
(
similarity_lan
,
tmp_sim
)
similarity
+=
similarity_lan
count
+=
1
return
similarity
/
count
if
count
else
0
def
string_to_tree
(
string
):
"""
This function converts ids string to a string that can be used as a tree input to APTED.
Any Error raised by this function implies that the input string is invalid.
>>> string_to_tree('⿱⿱⿰丿㇏⿰丿㇏⿱⿰丿㇏⿰丿㇏') # 炎
'{⿱{⿱{⿰{丿}{㇏}}{⿰{丿}{㇏}}}{⿱{⿰{丿}{㇏}}{⿰{丿}{㇏}}}}'
>>> string_to_tree('⿱⿰丿㇏⿱一⿱⿻一丨一') # 全
'{⿱{⿰{丿}{㇏}}{⿱{一}{⿱{⿻{一}{丨}}{一}}}}'
>>> string_to_tree('⿱⿰丿㇏⿻⿱一⿱⿻一丨一丷') # 金
'{⿱{⿰{丿}{㇏}}{⿻{⿱{一}{⿱{⿻{一}{丨}}{一}}}{丷}}}'
>>> string_to_tree('⿻⿻⿻一丨一⿴⿱⿰丨𠃌一一') # 車
'{⿻{⿻{⿻{一}{丨}}{一}}{⿴{⿱{⿰{丨}{𠃌}}{一}}{一}}}'
>>> string_to_tree('⿻⿻⿻一丨⿰丿㇏⿴⿱⿰丨𠃌一一') # 東
'{⿻{⿻{⿻{一}{丨}}{⿰{丿}{㇏}}}{⿴{⿱{⿰{丨}{𠃌}}{一}}{一}}}'
>>> string_to_tree('丿') # 丿
'{丿}'
>>> string_to_tree('⿻') # ⿻
'{⿻}'
"""
if
string
[
0
]
in
IDCS
and
len
(
string
)
!=
1
:
bracket_stack
=
[]
tree
=
[]
def
add_brackets
(
num
):
if
num
==
2
:
bracket_stack
.
extend
([
'}'
,
'{'
,
'}'
])
else
:
bracket_stack
.
extend
([
'}'
,
'{'
,
'}'
,
'{'
,
'}'
])
tree
.
append
(
'{'
)
global_just_put
=
'{'
for
c
in
string
:
tree
.
append
(
c
)
if
c
in
IDCS
:
assert
global_just_put
!=
'}'
add_brackets
(
IDCS
[
c
])
global_just_put
=
'{'
else
:
just_put
=
''
while
just_put
!=
'{'
and
bracket_stack
:
just_put
=
bracket_stack
.
pop
(
-
1
)
tree
.
append
(
just_put
)
global_just_put
=
just_put
res
=
''
.
join
(
tree
)
assert
res
[
-
1
]
==
'}'
else
:
assert
len
(
string
)
==
1
or
string
==
'null'
res
=
string
[
0
]
return
'{'
+
res
+
'}'
def
pinyin_map
(
standard_pinyin
):
"""
>>> pinyin_map('xuě')
'xue3'
>>> pinyin_map('xue')
'xue'
>>> pinyin_map('lǜ')
'lü4'
>>> pinyin_map('fá')
'fa2'
"""
tone
=
''
pinyin
=
''
assert
' '
not
in
standard_pinyin
for
c
in
standard_pinyin
:
if
c
in
PINYIN
:
pinyin
+=
PINYIN
[
c
][
0
]
assert
tone
==
''
tone
=
str
(
PINYIN
[
c
][
1
])
else
:
pinyin
+=
c
pinyin
+=
tone
return
pinyin
def
parse_args
():
usage
=
'
\n
1. You can compute character similarity by:
\n
'
\
'python char_sim.py 午 牛 年 千
\n
'
\
'
\n
'
\
'2. You can use ted in computing character similarity by:
\n
'
\
'python char_sim.py 午 牛 年 千 -t
\n
'
\
'
\n
'
parser
=
argparse
.
ArgumentParser
(
description
=
'A script to compute Chinese character (Kanji) similarity'
,
usage
=
usage
)
parser
.
add_argument
(
'multiargs'
,
nargs
=
'*'
,
type
=
str
,
default
=
None
,
help
=
'Chinese characters in question'
)
parser
.
add_argument
(
'--ted'
,
'-t'
,
action
=
"store_true"
,
default
=
False
,
help
=
'True=to use tree edit distence (TED)'
'False=to use string edit distance'
)
args
=
parser
.
parse_args
()
return
args
if
__name__
==
'__main__'
:
args
=
parse_args
()
c
=
CharFuncs
(
'data/char_meta.txt'
)
if
not
args
.
ted
:
for
i
,
c1
in
enumerate
(
args
.
multiargs
):
for
c2
in
args
.
multiargs
[
i
:]:
if
c1
!=
c2
:
print
(
f
'For character pair (
{
c1
}
,
{
c2
}
):'
)
print
(
f
' v-sim =
{
c
.
shape_similarity
(
c1
,
c2
)
}
'
)
print
(
f
' p-sim =
{
c
.
pronunciation_similarity
(
c1
,
c2
)
}
\n
'
)
else
:
for
i
,
c1
in
enumerate
(
args
.
multiargs
):
for
c2
in
args
.
multiargs
[
i
:]:
if
c1
!=
c2
:
print
(
f
'For character pair (
{
c1
}
,
{
c2
}
):'
)
print
(
f
' v-sim =
{
c
.
shape_similarity
(
c1
,
c2
,
as_tree
=
True
)
}
'
)
print
(
f
' p-sim =
{
c
.
pronunciation_similarity
(
c1
,
c2
)
}
\n
'
)
\ No newline at end of file
opencompass/datasets/lawbench/utils/compare_m2_for_evaluation.py
0 → 100644
View file @
861942ab
import
argparse
from
collections
import
Counter
def
main
():
# Parse command line args
args
=
parse_args
()
# Open hypothesis and reference m2 files and split into chunks
hyp_m2
=
open
(
args
.
hyp
).
read
().
strip
().
split
(
"
\n\n
"
)[
args
.
start
:
args
.
end
]
if
args
.
start
is
not
None
and
args
.
end
is
not
None
else
open
(
args
.
hyp
).
read
().
strip
().
split
(
"
\n\n
"
)
ref_m2
=
open
(
args
.
ref
).
read
().
strip
().
split
(
"
\n\n
"
)[
args
.
start
:
args
.
end
]
if
args
.
start
is
not
None
and
args
.
end
is
not
None
else
open
(
args
.
ref
).
read
().
strip
().
split
(
"
\n\n
"
)
# Make sure they have the same number of sentences
assert
len
(
hyp_m2
)
==
len
(
ref_m2
),
print
(
len
(
hyp_m2
),
len
(
ref_m2
))
# Store global corpus level best counts here
best_dict
=
Counter
({
"tp"
:
0
,
"fp"
:
0
,
"fn"
:
0
})
best_cats
=
{}
# Process each sentence
sents
=
zip
(
hyp_m2
,
ref_m2
)
for
sent_id
,
sent
in
enumerate
(
sents
):
# Simplify the edits into lists of lists
# if "A1" in sent[0] or "A1" in sent[1] or sent_id in sent_id_cons:
# sent_id_cons.append(sent_id)
src
=
sent
[
0
].
split
(
"
\n
"
)[
0
]
hyp_edits
=
simplify_edits
(
sent
[
0
],
args
.
max_answer_num
)
ref_edits
=
simplify_edits
(
sent
[
1
],
args
.
max_answer_num
)
# Process the edits for detection/correction based on args
hyp_dict
=
process_edits
(
hyp_edits
,
args
)
ref_dict
=
process_edits
(
ref_edits
,
args
)
if
args
.
reference_num
is
None
or
len
(
ref_dict
.
keys
())
==
args
.
reference_num
:
# Evaluate edits and get best TP, FP, FN hyp+ref combo.
count_dict
,
cat_dict
=
evaluate_edits
(
src
,
hyp_dict
,
ref_dict
,
best_dict
,
sent_id
,
args
)
# Merge these dicts with best_dict and best_cats
best_dict
+=
Counter
(
count_dict
)
best_cats
=
merge_dict
(
best_cats
,
cat_dict
)
# Print results
print_results
(
best_dict
,
best_cats
,
args
)
# Parse command line args
def
parse_args
():
parser
=
argparse
.
ArgumentParser
(
description
=
"Calculate F-scores for error detection and/or correction.
\n
"
"Flags let you evaluate at different levels of granularity."
,
formatter_class
=
argparse
.
RawTextHelpFormatter
,
usage
=
"%(prog)s [options] -hyp HYP -ref REF"
)
parser
.
add_argument
(
"-hyp"
,
help
=
"A hypothesis M2 file."
,
required
=
True
)
parser
.
add_argument
(
"-ref"
,
help
=
"A reference M2 file."
,
required
=
True
)
parser
.
add_argument
(
"--start"
,
type
=
int
,
default
=
None
)
parser
.
add_argument
(
"--end"
,
type
=
int
,
default
=
None
)
parser
.
add_argument
(
"--max_answer_num"
,
type
=
int
,
default
=
None
)
parser
.
add_argument
(
"--reference_num"
,
type
=
int
,
default
=
None
)
parser
.
add_argument
(
"-b"
,
"--beta"
,
help
=
"Value of beta in F-score. (default: 0.5)"
,
default
=
0.5
,
type
=
float
)
parser
.
add_argument
(
"-v"
,
"--verbose"
,
help
=
"Print verbose output."
,
action
=
"store_true"
)
eval_type
=
parser
.
add_mutually_exclusive_group
()
eval_type
.
add_argument
(
"-dt"
,
help
=
"Evaluate Detection in terms of Tokens."
,
action
=
"store_true"
)
eval_type
.
add_argument
(
"-ds"
,
help
=
"Evaluate Detection in terms of Spans."
,
action
=
"store_true"
)
eval_type
.
add_argument
(
"-cs"
,
help
=
"Evaluate Correction in terms of Spans. (default)"
,
action
=
"store_true"
)
eval_type
.
add_argument
(
"-cse"
,
help
=
"Evaluate Correction in terms of Spans and Error types."
,
action
=
"store_true"
)
parser
.
add_argument
(
"-single"
,
help
=
"Only evaluate single token edits; i.e. 0:1, 1:0 or 1:1"
,
action
=
"store_true"
)
parser
.
add_argument
(
"-multi"
,
help
=
"Only evaluate multi token edits; i.e. 2+:n or n:2+"
,
action
=
"store_true"
)
parser
.
add_argument
(
"-multi_hyp_avg"
,
help
=
"When get multiple hypotheses for a sentence, calculate their average F-scores for this sentence."
,
action
=
"store_true"
)
# For IAA calculation
parser
.
add_argument
(
"-multi_hyp_max"
,
help
=
"When get multiple hypotheses for a sentence, calculate their F-scores and select the max one for this sentence."
,
action
=
"store_true"
)
# For multiple hypotheses system evaluation
parser
.
add_argument
(
"-filt"
,
help
=
"Do not evaluate the specified error types."
,
nargs
=
"+"
,
default
=
[])
parser
.
add_argument
(
"-cat"
,
help
=
"Show error category scores.
\n
"
"1: Only show operation tier scores; e.g. R.
\n
"
"2: Only show main tier scores; e.g. NOUN.
\n
"
"3: Show all category scores; e.g. R:NOUN."
,
choices
=
[
1
,
2
,
3
],
type
=
int
)
args
=
parser
.
parse_args
()
return
args
# Input: An m2 format sentence with edits.
# Output: A list of lists. Each edit: [start, end, cat, cor, coder]
def
simplify_edits
(
sent
,
max_answer_num
):
out_edits
=
[]
# Get the edit lines from an m2 block.
edits
=
sent
.
split
(
"
\n
"
)
# Loop through the edits
for
edit
in
edits
:
# Preprocessing
if
edit
.
startswith
(
"A "
):
edit
=
edit
[
2
:].
split
(
"|||"
)
# Ignore "A " then split.
span
=
edit
[
0
].
split
()
start
=
int
(
span
[
0
])
end
=
int
(
span
[
1
])
cat
=
edit
[
1
]
cor
=
edit
[
2
].
replace
(
" "
,
""
)
coder
=
int
(
edit
[
-
1
])
out_edit
=
[
start
,
end
,
cat
,
cor
,
coder
]
out_edits
.
append
(
out_edit
)
# return [edit for edit in out_edits if edit[-1] in [0,1]]
if
max_answer_num
is
None
:
return
out_edits
elif
max_answer_num
==
1
:
return
[
edit
for
edit
in
out_edits
if
edit
[
-
1
]
==
0
]
elif
max_answer_num
==
2
:
return
[
edit
for
edit
in
out_edits
if
edit
[
-
1
]
in
[
0
,
1
]]
elif
max_answer_num
==
3
:
return
[
edit
for
edit
in
out_edits
if
edit
[
-
1
]
in
[
0
,
1
,
2
]]
# Input 1: A list of edits. Each edit: [start, end, cat, cor, coder]
# Input 2: Command line args
# Output: A dict; key is coder, value is edit dict.
def
process_edits
(
edits
,
args
):
coder_dict
=
{}
# Add an explicit noop edit if there are no edits.
if
not
edits
:
edits
=
[[
-
1
,
-
1
,
"noop"
,
"-NONE-"
,
0
]]
# Loop through the edits
for
edit
in
edits
:
# Name the edit elements for clarity
start
=
edit
[
0
]
end
=
edit
[
1
]
cat
=
edit
[
2
]
cor
=
edit
[
3
]
coder
=
edit
[
4
]
# Add the coder to the coder_dict if necessary
if
coder
not
in
coder_dict
:
coder_dict
[
coder
]
=
{}
# Optionally apply filters based on args
# 1. UNK type edits are only useful for detection, not correction.
if
not
args
.
dt
and
not
args
.
ds
and
cat
==
"UNK"
:
continue
# 2. Only evaluate single token edits; i.e. 0:1, 1:0 or 1:1
if
args
.
single
and
(
end
-
start
>=
2
or
len
(
cor
.
split
())
>=
2
):
continue
# 3. Only evaluate multi token edits; i.e. 2+:n or n:2+
if
args
.
multi
and
end
-
start
<
2
and
len
(
cor
.
split
())
<
2
:
continue
# 4. If there is a filter, ignore the specified error types
if
args
.
filt
and
cat
in
args
.
filt
:
continue
# Token Based Detection
if
args
.
dt
:
# Preserve noop edits.
if
start
==
-
1
:
if
(
start
,
start
)
in
coder_dict
[
coder
].
keys
():
coder_dict
[
coder
][(
start
,
start
)].
append
(
cat
)
else
:
coder_dict
[
coder
][(
start
,
start
)]
=
[
cat
]
# Insertions defined as affecting the token on the right
elif
start
==
end
and
start
>=
0
:
if
(
start
,
start
+
1
)
in
coder_dict
[
coder
].
keys
():
coder_dict
[
coder
][(
start
,
start
+
1
)].
append
(
cat
)
else
:
coder_dict
[
coder
][(
start
,
start
+
1
)]
=
[
cat
]
# Edit spans are split for each token in the range.
else
:
for
tok_id
in
range
(
start
,
end
):
if
(
tok_id
,
tok_id
+
1
)
in
coder_dict
[
coder
].
keys
():
coder_dict
[
coder
][(
tok_id
,
tok_id
+
1
)].
append
(
cat
)
else
:
coder_dict
[
coder
][(
tok_id
,
tok_id
+
1
)]
=
[
cat
]
# Span Based Detection
elif
args
.
ds
:
if
(
start
,
end
)
in
coder_dict
[
coder
].
keys
():
coder_dict
[
coder
][(
start
,
end
)].
append
(
cat
)
else
:
coder_dict
[
coder
][(
start
,
end
)]
=
[
cat
]
# Span Based Correction
else
:
# With error type classification
if
args
.
cse
:
if
(
start
,
end
,
cat
,
cor
)
in
coder_dict
[
coder
].
keys
():
coder_dict
[
coder
][(
start
,
end
,
cat
,
cor
)].
append
(
cat
)
else
:
coder_dict
[
coder
][(
start
,
end
,
cat
,
cor
)]
=
[
cat
]
# Without error type classification
else
:
if
(
start
,
end
,
cor
)
in
coder_dict
[
coder
].
keys
():
coder_dict
[
coder
][(
start
,
end
,
cor
)].
append
(
cat
)
else
:
coder_dict
[
coder
][(
start
,
end
,
cor
)]
=
[
cat
]
return
coder_dict
# Input 1: A hyp dict; key is coder_id, value is dict of processed hyp edits.
# Input 2: A ref dict; key is coder_id, value is dict of processed ref edits.
# Input 3: A dictionary of the best corpus level TP, FP and FN counts so far.
# Input 4: Sentence ID (for verbose output only)
# Input 5: Command line args
# Output 1: A dict of the best corpus level TP, FP and FN for the input sentence.
# Output 2: The corresponding error type dict for the above dict.
def
evaluate_edits
(
src
,
hyp_dict
,
ref_dict
,
best
,
sent_id
,
args
):
# Store the best sentence level scores and hyp+ref combination IDs
# best_f is initialised as -1 cause 0 is a valid result.
best_tp
,
best_fp
,
best_fn
,
best_f
,
best_hyp
,
best_ref
=
0
,
0
,
0
,
-
1
,
0
,
0
best_cat
=
{}
# skip not annotatable sentence
if
len
(
ref_dict
.
keys
())
==
1
:
ref_id
=
list
(
ref_dict
.
keys
())[
0
]
if
len
(
ref_dict
[
ref_id
].
keys
())
==
1
:
cat
=
list
(
ref_dict
[
ref_id
].
values
())[
0
][
0
]
if
cat
==
"NA"
:
best_dict
=
{
"tp"
:
best_tp
,
"fp"
:
best_fp
,
"fn"
:
best_fn
}
return
best_dict
,
best_cat
# Compare each hyp and ref combination
for
hyp_id
in
hyp_dict
.
keys
():
for
ref_id
in
ref_dict
.
keys
():
# Get the local counts for the current combination.
tp
,
fp
,
fn
,
cat_dict
=
compareEdits
(
hyp_dict
[
hyp_id
],
ref_dict
[
ref_id
])
# Compute the local sentence scores (for verbose output only)
loc_p
,
loc_r
,
loc_f
=
computeFScore
(
tp
,
fp
,
fn
,
args
.
beta
)
# Compute the global sentence scores
p
,
r
,
f
=
computeFScore
(
tp
+
best
[
"tp"
],
fp
+
best
[
"fp"
],
fn
+
best
[
"fn"
],
args
.
beta
)
# Save the scores if they are better in terms of:
# 1. Higher F-score
# 2. Same F-score, higher TP
# 3. Same F-score and TP, lower FP
# 4. Same F-score, TP and FP, lower FN
if
(
f
>
best_f
)
or
\
(
f
==
best_f
and
tp
>
best_tp
)
or
\
(
f
==
best_f
and
tp
==
best_tp
and
fp
<
best_fp
)
or
\
(
f
==
best_f
and
tp
==
best_tp
and
fp
==
best_fp
and
fn
<
best_fn
):
best_tp
,
best_fp
,
best_fn
=
tp
,
fp
,
fn
best_f
,
best_hyp
,
best_ref
=
f
,
hyp_id
,
ref_id
best_cat
=
cat_dict
# Verbose output
if
args
.
verbose
:
# Prepare verbose output edits.
hyp_verb
=
list
(
sorted
(
hyp_dict
[
hyp_id
].
keys
()))
ref_verb
=
list
(
sorted
(
ref_dict
[
ref_id
].
keys
()))
# Ignore noop edits
if
not
hyp_verb
or
hyp_verb
[
0
][
0
]
==
-
1
:
hyp_verb
=
[]
if
not
ref_verb
or
ref_verb
[
0
][
0
]
==
-
1
:
ref_verb
=
[]
# Print verbose info
print
(
'{:-^40}'
.
format
(
""
))
print
(
"SENTENCE "
+
str
(
sent_id
)
+
src
[
1
:])
print
(
'{:-^40}'
.
format
(
""
))
print
(
"SENTENCE "
+
str
(
sent_id
)
+
" - HYP "
+
str
(
hyp_id
)
+
" - REF "
+
str
(
ref_id
))
print
(
"HYPOTHESIS EDITS :"
,
hyp_verb
)
print
(
"REFERENCE EDITS :"
,
ref_verb
)
print
(
"Local TP/FP/FN :"
,
str
(
tp
),
str
(
fp
),
str
(
fn
))
print
(
"Local P/R/F"
+
str
(
args
.
beta
)
+
" :"
,
str
(
loc_p
),
str
(
loc_r
),
str
(
loc_f
))
print
(
"Global TP/FP/FN :"
,
str
(
tp
+
best
[
"tp"
]),
str
(
fp
+
best
[
"fp"
]),
str
(
fn
+
best
[
"fn"
]))
print
(
"Global P/R/F"
+
str
(
args
.
beta
)
+
" :"
,
str
(
p
),
str
(
r
),
str
(
f
))
# Verbose output: display the best hyp+ref combination
if
args
.
verbose
:
print
(
'{:-^40}'
.
format
(
""
))
print
(
"^^ HYP "
+
str
(
best_hyp
)
+
", REF "
+
str
(
best_ref
)
+
" chosen for sentence "
+
str
(
sent_id
))
# Save the best TP, FP and FNs as a dict, and return this and the best_cat dict
best_dict
=
{
"tp"
:
best_tp
,
"fp"
:
best_fp
,
"fn"
:
best_fn
}
return
best_dict
,
best_cat
# Input 1: A dictionary of hypothesis edits for a single system.
# Input 2: A dictionary of reference edits for a single annotator.
# Output 1-3: The TP, FP and FN for the hyp vs the given ref annotator.
# Output 4: A dictionary of the error type counts.
def
compareEdits
(
hyp_edits
,
ref_edits
):
tp
=
0
# True Positives
fp
=
0
# False Positives
fn
=
0
# False Negatives
cat_dict
=
{}
# {cat: [tp, fp, fn], ...}
for
h_edit
,
h_cats
in
hyp_edits
.
items
():
# noop hyp edits cannot be TP or FP
if
h_cats
[
0
]
==
"noop"
:
continue
# TRUE POSITIVES
if
h_edit
in
ref_edits
.
keys
():
# On occasion, multiple tokens at same span.
for
h_cat
in
ref_edits
[
h_edit
]:
# Use ref dict for TP
tp
+=
1
# Each dict value [TP, FP, FN]
if
h_cat
in
cat_dict
.
keys
():
cat_dict
[
h_cat
][
0
]
+=
1
else
:
cat_dict
[
h_cat
]
=
[
1
,
0
,
0
]
# FALSE POSITIVES
else
:
# On occasion, multiple tokens at same span.
for
h_cat
in
h_cats
:
fp
+=
1
# Each dict value [TP, FP, FN]
if
h_cat
in
cat_dict
.
keys
():
cat_dict
[
h_cat
][
1
]
+=
1
else
:
cat_dict
[
h_cat
]
=
[
0
,
1
,
0
]
for
r_edit
,
r_cats
in
ref_edits
.
items
():
# noop ref edits cannot be FN
if
r_cats
[
0
]
==
"noop"
:
continue
# FALSE NEGATIVES
if
r_edit
not
in
hyp_edits
.
keys
():
# On occasion, multiple tokens at same span.
for
r_cat
in
r_cats
:
fn
+=
1
# Each dict value [TP, FP, FN]
if
r_cat
in
cat_dict
.
keys
():
cat_dict
[
r_cat
][
2
]
+=
1
else
:
cat_dict
[
r_cat
]
=
[
0
,
0
,
1
]
return
tp
,
fp
,
fn
,
cat_dict
# Input 1-3: True positives, false positives, false negatives
# Input 4: Value of beta in F-score.
# Output 1-3: Precision, Recall and F-score rounded to 4dp.
def
computeFScore
(
tp
,
fp
,
fn
,
beta
):
p
=
float
(
tp
)
/
(
tp
+
fp
)
if
fp
else
1.0
r
=
float
(
tp
)
/
(
tp
+
fn
)
if
fn
else
1.0
f
=
float
((
1
+
(
beta
**
2
))
*
p
*
r
)
/
(((
beta
**
2
)
*
p
)
+
r
)
if
p
+
r
else
0.0
return
round
(
p
,
4
),
round
(
r
,
4
),
round
(
f
,
4
)
# Input 1-2: Two error category dicts. Key is cat, value is list of TP, FP, FN.
# Output: The dictionaries combined with cumulative TP, FP, FN.
def
merge_dict
(
dict1
,
dict2
):
for
cat
,
stats
in
dict2
.
items
():
if
cat
in
dict1
.
keys
():
dict1
[
cat
]
=
[
x
+
y
for
x
,
y
in
zip
(
dict1
[
cat
],
stats
)]
else
:
dict1
[
cat
]
=
stats
return
dict1
# Input 1: A dict; key is error cat, value is counts for [tp, fp, fn]
# Input 2: Integer value denoting level of error category granularity.
# 1: Operation tier; e.g. M, R, U. 2: Main tier; e.g. NOUN, VERB 3: Everything.
# Output: A dictionary of category TP, FP and FN based on Input 2.
def
processCategories
(
cat_dict
,
setting
):
# Otherwise, do some processing.
proc_cat_dict
=
{}
for
cat
,
cnt
in
cat_dict
.
items
():
if
cat
==
"UNK"
:
proc_cat_dict
[
cat
]
=
cnt
continue
# M, U, R or UNK combined only.
if
setting
==
1
:
if
cat
[
0
]
in
proc_cat_dict
.
keys
():
proc_cat_dict
[
cat
[
0
]]
=
[
x
+
y
for
x
,
y
in
zip
(
proc_cat_dict
[
cat
[
0
]],
cnt
)]
else
:
proc_cat_dict
[
cat
[
0
]]
=
cnt
# Everything without M, U or R.
elif
setting
==
2
:
if
cat
[
2
:]
in
proc_cat_dict
.
keys
():
proc_cat_dict
[
cat
[
2
:]]
=
[
x
+
y
for
x
,
y
in
zip
(
proc_cat_dict
[
cat
[
2
:]],
cnt
)]
else
:
proc_cat_dict
[
cat
[
2
:]]
=
cnt
# All error category combinations
else
:
return
cat_dict
return
proc_cat_dict
# Input 1: A dict of global best TP, FP and FNs
# Input 2: A dict of error types and counts for those TP, FP and FNs
# Input 3: Command line args
def
print_results
(
best
,
best_cats
,
args
):
# Prepare output title.
if
args
.
dt
:
title
=
" Token-Based Detection "
elif
args
.
ds
:
title
=
" Span-Based Detection "
elif
args
.
cse
:
title
=
" Span-Based Correction + Classification "
else
:
title
=
" Span-Based Correction "
# Category Scores
if
args
.
cat
:
best_cats
=
processCategories
(
best_cats
,
args
.
cat
)
print
(
""
)
print
(
'{:=^66}'
.
format
(
title
))
print
(
"Category"
.
ljust
(
14
),
"TP"
.
ljust
(
8
),
"FP"
.
ljust
(
8
),
"FN"
.
ljust
(
8
),
"P"
.
ljust
(
8
),
"R"
.
ljust
(
8
),
"F"
+
str
(
args
.
beta
))
for
cat
,
cnts
in
sorted
(
best_cats
.
items
()):
cat_p
,
cat_r
,
cat_f
=
computeFScore
(
cnts
[
0
],
cnts
[
1
],
cnts
[
2
],
args
.
beta
)
print
(
cat
.
ljust
(
14
),
str
(
cnts
[
0
]).
ljust
(
8
),
str
(
cnts
[
1
]).
ljust
(
8
),
str
(
cnts
[
2
]).
ljust
(
8
),
str
(
cat_p
).
ljust
(
8
),
str
(
cat_r
).
ljust
(
8
),
cat_f
)
# Print the overall results.
print
(
""
)
print
(
'{:=^46}'
.
format
(
title
))
print
(
"
\t
"
.
join
([
"TP"
,
"FP"
,
"FN"
,
"Prec"
,
"Rec"
,
"F"
+
str
(
args
.
beta
)]))
print
(
"
\t
"
.
join
(
map
(
str
,
[
best
[
"tp"
],
best
[
"fp"
],
best
[
"fn"
]]
+
list
(
computeFScore
(
best
[
"tp"
],
best
[
"fp"
],
best
[
"fn"
],
args
.
beta
)))))
print
(
'{:=^46}'
.
format
(
""
))
print
(
""
)
if
__name__
==
"__main__"
:
# Run the program
main
()
opencompass/datasets/lawbench/utils/comprehension_scores.py
0 → 100644
View file @
861942ab
import
re
from
..utils.rc_f1
import
CJRCEvaluator
"""
given a target substring. find its all occurances in the string s
return the starting and ending index of every occurance
"""
def
__find_substring_starts
(
s
,
target
):
return
[(
m
.
start
(),
m
.
end
())
for
m
in
re
.
finditer
(
target
,
s
)]
"""
compute the reading comprehension F1 scores
hyps and refs are lists of hyposisis and reference strings
"""
def
compute_rc_f1
(
hyps
,
refs
):
scores
=
0
for
h
,
r
in
zip
(
hyps
,
refs
):
scores
+=
CJRCEvaluator
.
compute_f1
(
r
,
h
)
return
{
'score'
:
scores
/
len
(
hyps
)}
"""
compute the information extraction F1 scores
hyps and refs are lists of hyposisis and reference strings
entity_types: a set of all possible entity types
"""
def
compute_ie_f1
(
hyps
,
refs
,
entity_types
):
assert
(
len
(
hyps
)
==
len
(
refs
))
scores
,
abstentions
=
0
,
0
for
h
,
r
in
zip
(
hyps
,
refs
):
h
=
__extract_entities_pred
(
h
,
entity_types
)
r
=
__extract_entities_ref
(
r
)
if
r
==
{}:
scores
+=
1
if
h
==
{}
else
0
continue
if
h
==
{}:
abstentions
+=
1
intersected
=
[
CJRCEvaluator
.
compute_f1
(
r
[
etype
],
einstance
)
for
etype
,
einstance
in
h
.
items
()
if
etype
in
r
]
prec
=
sum
(
intersected
)
/
len
(
h
)
if
len
(
h
)
>
0
else
0
rec
=
sum
(
intersected
)
/
len
(
r
)
if
len
(
r
)
>
0
else
0
# print(prec, rec, intersected)
scores
+=
2
*
prec
*
rec
/
(
prec
+
rec
+
1e-10
)
return
{
'score'
:
scores
/
len
(
hyps
),
"anstention_rate"
:
abstentions
/
len
(
hyps
)}
def
__extract_entities_ref
(
ref
):
outputs
=
{}
if
ref
.
strip
()
==
''
:
return
outputs
for
seg
in
ref
.
split
(
';'
):
seg
=
seg
.
split
(
':'
)
outputs
[
seg
[
0
]]
=
seg
[
1
]
return
outputs
"""
extract entity type and instances from the model prediction
pred: string of model prediction
entity_types: a set of all possible entity types
"""
def
__extract_entities_pred
(
pred
,
entity_types
):
outputs
=
{}
for
etype
in
entity_types
:
occurances
=
__find_substring_starts
(
pred
,
etype
)
for
start
,
end
in
occurances
:
if
end
>=
(
len
(
pred
)
-
2
):
continue
if
pred
[
end
]
==
":"
or
pred
[
end
]
==
":"
:
einstance
=
re
.
split
(
"
\n
| "
,
pred
[
end
+
1
:].
strip
())[
0
].
strip
()
if
einstance
!=
'无'
and
einstance
!=
'未提及'
:
outputs
[
etype
]
=
einstance
return
outputs
opencompass/datasets/lawbench/utils/function_utils.py
0 → 100644
View file @
861942ab
from
rouge_chinese
import
Rouge
import
jieba
from
nltk.translate.gleu_score
import
corpus_gleu
def
compute_f1_two_sets
(
pred_set
,
gt_set
):
precision
=
len
(
pred_set
.
intersection
(
gt_set
))
/
len
(
pred_set
)
if
len
(
pred_set
)
>
0
else
0
recall
=
len
(
pred_set
.
intersection
(
gt_set
))
/
len
(
gt_set
)
if
len
(
gt_set
)
>
0
else
0
f1
=
2
*
precision
*
recall
/
(
precision
+
recall
)
if
precision
+
recall
>
0
else
0
return
f1
def
multi_choice_judge
(
prediction
,
option_list
,
answer_token
):
# a dict, key: letters in the option list, value: count of the letter in the prediction
count_dict
,
abstention
,
accuracy
=
{},
0
,
0
for
option
in
option_list
:
option_count
=
prediction
.
count
(
option
)
count_dict
[
option
]
=
1
if
option_count
>
0
else
0
# multiple occurrence of the same letter is counted as 1
if
sum
(
count_dict
.
values
())
==
0
:
abstention
=
1
# if the answer token is the only predicted token, the prediction is correct
elif
count_dict
[
answer_token
]
==
1
and
sum
(
count_dict
.
values
())
==
1
:
accuracy
=
1
return
{
"score"
:
accuracy
,
"abstention"
:
abstention
}
"""
compute the rouge score.
hyps and refs are lists of hyposisis and reference strings
empty predictions are replaces with 无内容
"""
def
compute_rouge
(
hyps
,
refs
):
assert
(
len
(
hyps
)
==
len
(
refs
))
hyps
=
[
' '
.
join
(
jieba
.
cut
(
h
))
for
h
in
hyps
]
hyps
=
[
h
if
h
.
strip
()
!=
""
else
"无内容"
for
h
in
hyps
]
refs
=
[
' '
.
join
(
jieba
.
cut
(
r
))
for
r
in
refs
]
return
Rouge
().
get_scores
(
hyps
,
refs
)
"""
compute the gleu score.
hyps and refs are lists of hyposisis and reference strings
empty predictions are replaces with 无内容
"""
def
compute_gleu
(
hyps
,
refs
):
assert
(
len
(
hyps
)
==
len
(
refs
))
hyps
=
[
' '
.
join
(
jieba
.
cut
(
h
))
for
h
in
hyps
]
hyps
=
[
h
if
h
.
strip
()
!=
""
else
"无内容"
for
h
in
hyps
]
refs
=
[[
' '
.
join
(
jieba
.
cut
(
r
))]
for
r
in
refs
]
return
corpus_gleu
(
refs
,
hyps
)
opencompass/datasets/lawbench/utils/modules/alignment.py
0 → 100644
View file @
861942ab
import
numpy
as
np
from
typing
import
List
,
Tuple
,
Dict
from
modules.tokenizer
import
Tokenizer
import
os
from
string
import
punctuation
REAL_PATH
=
os
.
path
.
split
(
os
.
path
.
realpath
(
__file__
))[
0
]
chinese_punct
=
"!?。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘'‛“”„‟…‧﹏"
english_punct
=
punctuation
punct
=
chinese_punct
+
english_punct
def
check_all_chinese
(
word
):
"""
判断一个单词是否全部由中文组成
:param word:
:return:
"""
return
all
([
'
\u4e00
'
<=
ch
<=
'
\u9fff
'
for
ch
in
word
])
def
read_cilin
():
"""
Cilin 詞林 is a thesaurus with semantic information
"""
# TODO -- fix this path
project_dir
=
os
.
path
.
dirname
(
os
.
path
.
dirname
(
__file__
))
# ymliu@2023.5.30 fix the path
lines
=
open
(
os
.
path
.
join
(
project_dir
,
"data"
,
"cilin.txt"
),
"r"
,
encoding
=
"gbk"
).
read
().
strip
().
split
(
"
\n
"
)
semantic_dict
=
{}
semantic_classes
=
{}
for
line
in
lines
:
code
,
*
words
=
line
.
split
(
" "
)
for
word
in
words
:
semantic_dict
[
word
]
=
code
# make reverse dict
if
code
in
semantic_classes
:
semantic_classes
[
code
]
+=
words
else
:
semantic_classes
[
code
]
=
words
return
semantic_dict
,
semantic_classes
def
read_confusion
():
confusion_dict
=
{}
project_dir
=
os
.
path
.
dirname
(
os
.
path
.
dirname
(
__file__
))
# ymliu@2023.5.30 fix the path
with
open
(
os
.
path
.
join
(
project_dir
,
"data"
,
"confusion_dict.txt"
),
"r"
,
encoding
=
"utf-8"
)
as
f
:
for
line
in
f
:
li
=
line
.
rstrip
(
'
\n
'
).
split
(
" "
)
confusion_dict
[
li
[
0
]]
=
li
[
1
:]
return
confusion_dict
class
Alignment
:
"""
对齐错误句子和正确句子,
使用编辑距离算法抽取编辑操作
"""
def
__init__
(
self
,
semantic_dict
:
Dict
,
confusion_dict
:
Dict
,
granularity
:
str
=
"word"
,
)
->
None
:
"""
构造函数
:param semantic_dict: 语义词典(大词林)
:param confusion_dict: 字符混淆集
"""
self
.
insertion_cost
=
1
self
.
deletion_cost
=
1
self
.
semantic_dict
=
semantic_dict
self
.
confusion_dict
=
confusion_dict
# Because we use character level tokenization, this doesn't currently use POS
self
.
_open_pos
=
{}
# 如果是词级别,还可以利用词性是否相同来计算cost
self
.
granularity
=
granularity
# word-level or character-level
self
.
align_seqs
=
[]
def
__call__
(
self
,
src
:
List
[
Tuple
],
tgt
:
List
[
Tuple
],
verbose
:
bool
=
False
):
cost_matrix
,
oper_matrix
=
self
.
align
(
src
,
tgt
)
align_seq
=
self
.
get_cheapest_align_seq
(
oper_matrix
)
if
verbose
:
print
(
"========== Seg. and POS: =========="
)
print
(
src
)
print
(
tgt
)
print
(
"========== Cost Matrix =========="
)
print
(
cost_matrix
)
print
(
"========== Oper Matrix =========="
)
print
(
oper_matrix
)
print
(
"========== Alignment =========="
)
print
(
align_seq
)
print
(
"========== Results =========="
)
for
a
in
align_seq
:
print
(
a
[
0
],
src
[
a
[
1
]:
a
[
2
]],
tgt
[
a
[
3
]:
a
[
4
]])
return
align_seq
def
_get_semantic_class
(
self
,
word
):
"""
NOTE: Based on the paper:
Improved-Edit-Distance Kernel for Chinese Relation Extraction
获取每个词语的语义类别(基于大词林,有三个级别)
"""
if
word
in
self
.
semantic_dict
:
code
=
self
.
semantic_dict
[
word
]
high
,
mid
,
low
=
code
[
0
],
code
[
1
],
code
[
2
:
4
]
return
high
,
mid
,
low
else
:
# unknown
return
None
@
staticmethod
def
_get_class_diff
(
a_class
,
b_class
):
"""
d == 3 for equivalent semantics
d == 0 for completely different semantics
根据大词林的信息,计算两个词的语义类别的差距
"""
d
=
sum
([
a
==
b
for
a
,
b
in
zip
(
a_class
,
b_class
)])
return
d
def
_get_semantic_cost
(
self
,
a
,
b
):
"""
计算基于语义信息的替换操作cost
:param a: 单词a的语义类别
:param b: 单词b的语义类别
:return: 替换编辑代价
"""
a_class
=
self
.
_get_semantic_class
(
a
)
b_class
=
self
.
_get_semantic_class
(
b
)
# unknown class, default to 1
if
a_class
is
None
or
b_class
is
None
:
return
4
elif
a_class
==
b_class
:
return
0
else
:
return
2
*
(
3
-
self
.
_get_class_diff
(
a_class
,
b_class
))
def
_get_pos_cost
(
self
,
a_pos
,
b_pos
):
"""
计算基于词性信息的编辑距离cost
:param a_pos: 单词a的词性
:param b_pos: 单词b的词性
:return: 替换编辑代价
"""
if
a_pos
==
b_pos
:
return
0
elif
a_pos
in
self
.
_open_pos
and
b_pos
in
self
.
_open_pos
:
return
0.25
else
:
return
0.499
def
_get_char_cost
(
self
,
a
,
b
,
pinyin_a
,
pinyin_b
):
"""
NOTE: This is a replacement of ERRANTS lemma cost for Chinese
计算基于字符相似度的编辑距离cost
"""
if
not
(
check_all_chinese
(
a
)
and
check_all_chinese
(
b
)):
return
0.5
if
len
(
a
)
>
len
(
b
):
a
,
b
=
b
,
a
pinyin_a
,
pinyin_b
=
pinyin_b
,
pinyin_a
if
a
==
b
:
return
0
else
:
return
self
.
_get_spell_cost
(
a
,
b
,
pinyin_a
,
pinyin_b
)
def
_get_spell_cost
(
self
,
a
,
b
,
pinyin_a
,
pinyin_b
):
"""
计算两个单词拼写相似度,分别由字形相似度和字音相似度组成
:param a: 单词a
:param b: 单词b,且单词a的长度小于等于b
:param pinyin_a: 单词a的拼音
:param pinyin_b: 单词b的拼音
:return: 替换操作cost
"""
count
=
0
for
i
in
range
(
len
(
a
)):
for
j
in
range
(
len
(
b
)):
if
a
[
i
]
==
b
[
j
]
or
(
set
(
pinyin_a
)
&
set
(
pinyin_b
))
or
(
b
[
j
]
in
self
.
confusion_dict
.
keys
()
and
a
[
i
]
in
self
.
confusion_dict
[
b
[
j
]])
or
(
a
[
i
]
in
self
.
confusion_dict
.
keys
()
and
b
[
j
]
in
self
.
confusion_dict
[
a
[
i
]]):
count
+=
1
break
return
(
len
(
a
)
-
count
)
/
(
len
(
a
)
*
2
)
def
get_sub_cost
(
self
,
a_seg
,
b_seg
):
"""
Calculate the substitution cost between words a and b
计算两个单词替换操作的编辑cost,最大为2,等于一次删除和一次添加
"""
if
a_seg
[
0
]
==
b_seg
[
0
]:
return
0
if
self
.
granularity
==
"word"
:
# 词级别可以额外利用词性信息
semantic_cost
=
self
.
_get_semantic_cost
(
a_seg
[
0
],
b_seg
[
0
])
/
6.0
pos_cost
=
self
.
_get_pos_cost
(
a_seg
[
1
],
b_seg
[
1
])
char_cost
=
self
.
_get_char_cost
(
a_seg
[
0
],
b_seg
[
0
],
a_seg
[
2
],
b_seg
[
2
])
return
semantic_cost
+
pos_cost
+
char_cost
else
:
# 字级别只能利用字义信息(从大词林中获取)和字面相似度信息
semantic_cost
=
self
.
_get_semantic_cost
(
a_seg
[
0
],
b_seg
[
0
])
/
6.0
if
a_seg
[
0
]
in
punct
and
b_seg
[
0
]
in
punct
:
pos_cost
=
0.0
elif
a_seg
[
0
]
not
in
punct
and
b_seg
[
0
]
not
in
punct
:
pos_cost
=
0.25
else
:
pos_cost
=
0.499
# pos_cost = 0.0 if (a_seg[0] in punct and b_seg[0] in punct) or (a_seg[0] not in punct and b_seg[0] not in punct) else 0.5
char_cost
=
self
.
_get_char_cost
(
a_seg
[
0
],
b_seg
[
0
],
a_seg
[
2
],
b_seg
[
2
])
return
semantic_cost
+
char_cost
+
pos_cost
def
align
(
self
,
src
:
List
[
Tuple
],
tgt
:
List
[
Tuple
]):
"""
Based on ERRANT's alignment
基于改进的动态规划算法,为原句子的每个字打上编辑标签,以便使它能够成功转换为目标句子。
编辑操作类别:
1) M:Match,即KEEP,即当前字保持不变
2) D:Delete,删除,即当前字需要被删除
3) I:Insert,插入,即当前字需要被插入
4) T:Transposition,移位操作,即涉及到词序问题
"""
cost_matrix
=
np
.
zeros
((
len
(
src
)
+
1
,
len
(
tgt
)
+
1
))
# 编辑cost矩阵
oper_matrix
=
np
.
full
(
(
len
(
src
)
+
1
,
len
(
tgt
)
+
1
),
"O"
,
dtype
=
object
)
# 操作矩阵
# Fill in the edges
for
i
in
range
(
1
,
len
(
src
)
+
1
):
cost_matrix
[
i
][
0
]
=
cost_matrix
[
i
-
1
][
0
]
+
1
oper_matrix
[
i
][
0
]
=
[
"D"
]
for
j
in
range
(
1
,
len
(
tgt
)
+
1
):
cost_matrix
[
0
][
j
]
=
cost_matrix
[
0
][
j
-
1
]
+
1
oper_matrix
[
0
][
j
]
=
[
"I"
]
# Loop through the cost matrix
for
i
in
range
(
len
(
src
)):
for
j
in
range
(
len
(
tgt
)):
# Matches
if
src
[
i
][
0
]
==
tgt
[
j
][
0
]:
# 如果两个字相等,则匹配成功(Match),编辑距离为0
cost_matrix
[
i
+
1
][
j
+
1
]
=
cost_matrix
[
i
][
j
]
oper_matrix
[
i
+
1
][
j
+
1
]
=
[
"M"
]
# Non-matches
else
:
del_cost
=
cost_matrix
[
i
][
j
+
1
]
+
self
.
deletion_cost
# 由删除动作得到的总cost
ins_cost
=
cost_matrix
[
i
+
1
][
j
]
+
self
.
insertion_cost
# 由插入动作得到的总cost
sub_cost
=
cost_matrix
[
i
][
j
]
+
self
.
get_sub_cost
(
src
[
i
],
tgt
[
j
]
)
# 由替换动作得到的总cost
# Calculate transposition cost
# 计算移位操作的总cost
trans_cost
=
float
(
"inf"
)
k
=
1
while
(
i
-
k
>=
0
and
j
-
k
>=
0
and
cost_matrix
[
i
-
k
+
1
][
j
-
k
+
1
]
!=
cost_matrix
[
i
-
k
][
j
-
k
]
):
p1
=
sorted
([
a
[
0
]
for
a
in
src
][
i
-
k
:
i
+
1
])
p2
=
sorted
([
b
[
0
]
for
b
in
tgt
][
j
-
k
:
j
+
1
])
if
p1
==
p2
:
trans_cost
=
cost_matrix
[
i
-
k
][
j
-
k
]
+
k
break
k
+=
1
costs
=
[
trans_cost
,
sub_cost
,
ins_cost
,
del_cost
]
ind
=
costs
.
index
(
min
(
costs
))
cost_matrix
[
i
+
1
][
j
+
1
]
=
costs
[
ind
]
# ind = costs.index(costs[ind], ind+1)
for
idx
,
cost
in
enumerate
(
costs
):
if
cost
==
costs
[
ind
]:
if
idx
==
0
:
if
oper_matrix
[
i
+
1
][
j
+
1
]
==
"O"
:
oper_matrix
[
i
+
1
][
j
+
1
]
=
[
"T"
+
str
(
k
+
1
)]
else
:
oper_matrix
[
i
+
1
][
j
+
1
].
append
(
"T"
+
str
(
k
+
1
))
elif
idx
==
1
:
if
oper_matrix
[
i
+
1
][
j
+
1
]
==
"O"
:
oper_matrix
[
i
+
1
][
j
+
1
]
=
[
"S"
]
else
:
oper_matrix
[
i
+
1
][
j
+
1
].
append
(
"S"
)
elif
idx
==
2
:
if
oper_matrix
[
i
+
1
][
j
+
1
]
==
"O"
:
oper_matrix
[
i
+
1
][
j
+
1
]
=
[
"I"
]
else
:
oper_matrix
[
i
+
1
][
j
+
1
].
append
(
"I"
)
else
:
if
oper_matrix
[
i
+
1
][
j
+
1
]
==
"O"
:
oper_matrix
[
i
+
1
][
j
+
1
]
=
[
"D"
]
else
:
oper_matrix
[
i
+
1
][
j
+
1
].
append
(
"D"
)
return
cost_matrix
,
oper_matrix
def
_dfs
(
self
,
i
,
j
,
align_seq_now
,
oper_matrix
,
strategy
=
"all"
):
"""
深度优先遍历,获取最小编辑距离相同的所有序列
"""
if
i
+
j
==
0
:
self
.
align_seqs
.
append
(
align_seq_now
)
else
:
ops
=
oper_matrix
[
i
][
j
]
# 可以类比成搜索一棵树从根结点到叶子结点的所有路径
if
strategy
!=
"all"
:
ops
=
ops
[:
1
]
for
op
in
ops
:
if
op
in
{
"M"
,
"S"
}:
self
.
_dfs
(
i
-
1
,
j
-
1
,
align_seq_now
+
[(
op
,
i
-
1
,
i
,
j
-
1
,
j
)],
oper_matrix
,
strategy
)
elif
op
==
"D"
:
self
.
_dfs
(
i
-
1
,
j
,
align_seq_now
+
[(
op
,
i
-
1
,
i
,
j
,
j
)],
oper_matrix
,
strategy
)
elif
op
==
"I"
:
self
.
_dfs
(
i
,
j
-
1
,
align_seq_now
+
[(
op
,
i
,
i
,
j
-
1
,
j
)],
oper_matrix
,
strategy
)
else
:
k
=
int
(
op
[
1
:])
self
.
_dfs
(
i
-
k
,
j
-
k
,
align_seq_now
+
[(
op
,
i
-
k
,
i
,
j
-
k
,
j
)],
oper_matrix
,
strategy
)
def
get_cheapest_align_seq
(
self
,
oper_matrix
):
"""
回溯获得编辑距离最小的编辑序列
"""
self
.
align_seqs
=
[]
i
=
oper_matrix
.
shape
[
0
]
-
1
j
=
oper_matrix
.
shape
[
1
]
-
1
if
abs
(
i
-
j
)
>
10
:
self
.
_dfs
(
i
,
j
,
[],
oper_matrix
,
"first"
)
else
:
self
.
_dfs
(
i
,
j
,
[],
oper_matrix
,
"all"
)
final_align_seqs
=
[
seq
[::
-
1
]
for
seq
in
self
.
align_seqs
]
return
final_align_seqs
if
__name__
==
"__main__"
:
tokenizer
=
Tokenizer
(
"word"
)
semantic_dict
,
semantic_class
=
read_cilin
()
confusion_dict
=
read_confusion
()
alignment
=
Alignment
(
semantic_dict
,
confusion_dict
)
sents
=
[
"首先 , 我们 得 准备 : 大 虾六 到 九 只 、 盐 一 茶匙 、 已 搾 好 的 柠檬汁 三 汤匙 、 泰国 柠檬 叶三叶 、 柠檬 香草 一 根 、 鱼酱 两 汤匙 、 辣椒 6 粒 , 纯净 水 4量杯 、 香菜 半量杯 和 草菇 10 个 。"
.
replace
(
" "
,
""
),
"首先 , 我们 得 准备 : 大 虾六 到 九 只 、 盐 一 茶匙 、 已 榨 好 的 柠檬汁 三 汤匙 、 泰国 柠檬 叶三叶 、 柠檬 香草 一 根 、 鱼酱 两 汤匙 、 辣椒 六 粒 , 纯净 水 四 量杯 、 香菜 半量杯 和 草菇 十 个 。"
.
replace
(
" "
,
""
)]
src
,
tgt
=
tokenizer
(
sents
)
alignment
(
src
,
tgt
,
verbose
=
True
)
\ No newline at end of file
opencompass/datasets/lawbench/utils/modules/annotator.py
0 → 100644
View file @
861942ab
from
typing
import
List
,
Tuple
from
modules.alignment
import
read_cilin
,
read_confusion
,
Alignment
from
modules.merger
import
Merger
from
modules.classifier
import
Classifier
class
Annotator
:
def
__init__
(
self
,
align
:
Alignment
,
merger
:
Merger
,
classifier
:
Classifier
,
granularity
:
str
=
"word"
,
strategy
:
str
=
"first"
):
self
.
align
=
align
self
.
merger
=
merger
self
.
classifier
=
classifier
self
.
granularity
=
granularity
self
.
strategy
=
strategy
@
classmethod
def
create_default
(
cls
,
granularity
:
str
=
"word"
,
strategy
:
str
=
"first"
):
"""
Default parameters used in the paper
"""
semantic_dict
,
semantic_class
=
read_cilin
()
confusion_dict
=
read_confusion
()
align
=
Alignment
(
semantic_dict
,
confusion_dict
,
granularity
)
merger
=
Merger
(
granularity
)
classifier
=
Classifier
(
granularity
)
return
cls
(
align
,
merger
,
classifier
,
granularity
,
strategy
)
def
__call__
(
self
,
src
:
List
[
Tuple
],
tgt
:
List
[
Tuple
],
annotator_id
:
int
=
0
,
verbose
:
bool
=
False
):
"""
Align sentences and annotate them with error type information
"""
src_tokens
=
[
x
[
0
]
for
x
in
src
]
tgt_tokens
=
[
x
[
0
]
for
x
in
tgt
]
src_str
=
""
.
join
(
src_tokens
)
tgt_str
=
""
.
join
(
tgt_tokens
)
# convert to text form
annotations_out
=
[
"S "
+
" "
.
join
(
src_tokens
)
+
"
\n
"
]
if
tgt_str
==
"没有错误"
or
src_str
==
tgt_str
:
# Error Free Case
annotations_out
.
append
(
f
"T
{
annotator_id
}
没有错误
\n
"
)
cors
=
[
tgt_str
]
op
,
toks
,
inds
=
"noop"
,
"-NONE-"
,
(
-
1
,
-
1
)
a_str
=
f
"A
{
inds
[
0
]
}
{
inds
[
1
]
}
|||
{
op
}
|||
{
toks
}
|||REQUIRED|||-NONE-|||
{
annotator_id
}
\n
"
annotations_out
.
append
(
a_str
)
elif
tgt_str
==
"无法标注"
:
# Not Annotatable Case
annotations_out
.
append
(
f
"T
{
annotator_id
}
无法标注
\n
"
)
cors
=
[
tgt_str
]
op
,
toks
,
inds
=
"NA"
,
"-NONE-"
,
(
-
1
,
-
1
)
a_str
=
f
"A
{
inds
[
0
]
}
{
inds
[
1
]
}
|||
{
op
}
|||
{
toks
}
|||REQUIRED|||-NONE-|||
{
annotator_id
}
\n
"
annotations_out
.
append
(
a_str
)
else
:
# Other
align_objs
=
self
.
align
(
src
,
tgt
)
edit_objs
=
[]
align_idx
=
0
if
self
.
strategy
==
"first"
:
align_objs
=
align_objs
[:
1
]
for
align_obj
in
align_objs
:
edits
=
self
.
merger
(
align_obj
,
src
,
tgt
,
verbose
)
if
edits
not
in
edit_objs
:
edit_objs
.
append
(
edits
)
annotations_out
.
append
(
f
"T
{
annotator_id
}
-A
{
align_idx
}
"
+
" "
.
join
(
tgt_tokens
)
+
"
\n
"
)
align_idx
+=
1
cors
=
self
.
classifier
(
src
,
tgt
,
edits
,
verbose
)
# annotations_out = []
for
cor
in
cors
:
op
,
toks
,
inds
=
cor
.
op
,
cor
.
toks
,
cor
.
inds
a_str
=
f
"A
{
inds
[
0
]
}
{
inds
[
1
]
}
|||
{
op
}
|||
{
toks
}
|||REQUIRED|||-NONE-|||
{
annotator_id
}
\n
"
annotations_out
.
append
(
a_str
)
annotations_out
.
append
(
"
\n
"
)
return
annotations_out
,
cors
opencompass/datasets/lawbench/utils/modules/classifier.py
0 → 100644
View file @
861942ab
from
char_smi
import
CharFuncs
from
collections
import
namedtuple
from
pypinyin
import
pinyin
,
Style
import
os
Correction
=
namedtuple
(
"Correction"
,
[
"op"
,
"toks"
,
"inds"
,
],
)
file_path
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
char_smi
=
CharFuncs
(
os
.
path
.
join
(
file_path
.
replace
(
"modules"
,
""
),
'data/char_meta.txt'
))
def
check_spell_error
(
src_span
:
str
,
tgt_span
:
str
,
threshold
:
float
=
0.8
)
->
bool
:
if
len
(
src_span
)
!=
len
(
tgt_span
):
return
False
src_chars
=
[
ch
for
ch
in
src_span
]
tgt_chars
=
[
ch
for
ch
in
tgt_span
]
if
sorted
(
src_chars
)
==
sorted
(
tgt_chars
):
# 词内部字符异位
return
True
for
src_char
,
tgt_char
in
zip
(
src_chars
,
tgt_chars
):
if
src_char
!=
tgt_char
:
if
src_char
not
in
char_smi
.
data
or
tgt_char
not
in
char_smi
.
data
:
return
False
v_sim
=
char_smi
.
shape_similarity
(
src_char
,
tgt_char
)
p_sim
=
char_smi
.
pronunciation_similarity
(
src_char
,
tgt_char
)
if
v_sim
+
p_sim
<
threshold
and
not
(
set
(
pinyin
(
src_char
,
style
=
Style
.
NORMAL
,
heteronym
=
True
)[
0
])
&
set
(
pinyin
(
tgt_char
,
style
=
Style
.
NORMAL
,
heteronym
=
True
)[
0
])):
return
False
return
True
class
Classifier
:
"""
错误类型分类器
"""
def
__init__
(
self
,
granularity
:
str
=
"word"
):
self
.
granularity
=
granularity
@
staticmethod
def
get_pos_type
(
pos
):
if
pos
in
{
"n"
,
"nd"
}:
return
"NOUN"
if
pos
in
{
"nh"
,
"ni"
,
"nl"
,
"ns"
,
"nt"
,
"nz"
}:
return
"NOUN-NE"
if
pos
in
{
"v"
}:
return
"VERB"
if
pos
in
{
"a"
,
"b"
}:
return
"ADJ"
if
pos
in
{
"c"
}:
return
"CONJ"
if
pos
in
{
"r"
}:
return
"PRON"
if
pos
in
{
"d"
}:
return
"ADV"
if
pos
in
{
"u"
}:
return
"AUX"
# if pos in {"k"}: # TODO 后缀词比例太少,暂且分入其它
# return "SUFFIX"
if
pos
in
{
"m"
}:
return
"NUM"
if
pos
in
{
"p"
}:
return
"PREP"
if
pos
in
{
"q"
}:
return
"QUAN"
if
pos
in
{
"wp"
}:
return
"PUNCT"
return
"OTHER"
def
__call__
(
self
,
src
,
tgt
,
edits
,
verbose
:
bool
=
False
):
"""
为编辑操作划分错误类型
:param src: 错误句子信息
:param tgt: 正确句子信息
:param edits: 编辑操作
:param verbose: 是否打印信息
:return: 划分完错误类型后的编辑操作
"""
results
=
[]
src_tokens
=
[
x
[
0
]
for
x
in
src
]
tgt_tokens
=
[
x
[
0
]
for
x
in
tgt
]
for
edit
in
edits
:
error_type
=
edit
[
0
]
src_span
=
" "
.
join
(
src_tokens
[
edit
[
1
]:
edit
[
2
]])
tgt_span
=
" "
.
join
(
tgt_tokens
[
edit
[
3
]:
edit
[
4
]])
# print(tgt_span)
cor
=
None
if
error_type
[
0
]
==
"T"
:
cor
=
Correction
(
"W"
,
tgt_span
,
(
edit
[
1
],
edit
[
2
]))
elif
error_type
[
0
]
==
"D"
:
if
self
.
granularity
==
"word"
:
# 词级别可以细分错误类型
if
edit
[
2
]
-
edit
[
1
]
>
1
:
# 词组冗余暂时分为OTHER
cor
=
Correction
(
"R:OTHER"
,
"-NONE-"
,
(
edit
[
1
],
edit
[
2
]))
else
:
pos
=
self
.
get_pos_type
(
src
[
edit
[
1
]][
1
])
pos
=
"NOUN"
if
pos
==
"NOUN-NE"
else
pos
pos
=
"MC"
if
tgt_span
==
"[缺失成分]"
else
pos
cor
=
Correction
(
"R:{:s}"
.
format
(
pos
),
"-NONE-"
,
(
edit
[
1
],
edit
[
2
]))
else
:
# 字级别可以只需要根据操作划分类型即可
cor
=
Correction
(
"R"
,
"-NONE-"
,
(
edit
[
1
],
edit
[
2
]))
elif
error_type
[
0
]
==
"I"
:
if
self
.
granularity
==
"word"
:
# 词级别可以细分错误类型
if
edit
[
4
]
-
edit
[
3
]
>
1
:
# 词组丢失暂时分为OTHER
cor
=
Correction
(
"M:OTHER"
,
tgt_span
,
(
edit
[
1
],
edit
[
2
]))
else
:
pos
=
self
.
get_pos_type
(
tgt
[
edit
[
3
]][
1
])
pos
=
"NOUN"
if
pos
==
"NOUN-NE"
else
pos
pos
=
"MC"
if
tgt_span
==
"[缺失成分]"
else
pos
cor
=
Correction
(
"M:{:s}"
.
format
(
pos
),
tgt_span
,
(
edit
[
1
],
edit
[
2
]))
else
:
# 字级别可以只需要根据操作划分类型即可
cor
=
Correction
(
"M"
,
tgt_span
,
(
edit
[
1
],
edit
[
2
]))
elif
error_type
[
0
]
==
"S"
:
if
self
.
granularity
==
"word"
:
# 词级别可以细分错误类型
if
check_spell_error
(
src_span
.
replace
(
" "
,
""
),
tgt_span
.
replace
(
" "
,
""
)):
cor
=
Correction
(
"S:SPELL"
,
tgt_span
,
(
edit
[
1
],
edit
[
2
]))
# Todo 暂且不单独区分命名实体拼写错误
# if edit[4] - edit[3] > 1:
# cor = Correction("S:SPELL:COMMON", tgt_span, (edit[1], edit[2]))
# else:
# pos = self.get_pos_type(tgt[edit[3]][1])
# if pos == "NOUN-NE": # 命名实体拼写有误
# cor = Correction("S:SPELL:NE", tgt_span, (edit[1], edit[2]))
# else: # 普通词语拼写有误
# cor = Correction("S:SPELL:COMMON", tgt_span, (edit[1], edit[2]))
else
:
if
edit
[
4
]
-
edit
[
3
]
>
1
:
# 词组被替换暂时分为OTHER
cor
=
Correction
(
"S:OTHER"
,
tgt_span
,
(
edit
[
1
],
edit
[
2
]))
else
:
pos
=
self
.
get_pos_type
(
tgt
[
edit
[
3
]][
1
])
pos
=
"NOUN"
if
pos
==
"NOUN-NE"
else
pos
pos
=
"MC"
if
tgt_span
==
"[缺失成分]"
else
pos
cor
=
Correction
(
"S:{:s}"
.
format
(
pos
),
tgt_span
,
(
edit
[
1
],
edit
[
2
]))
else
:
# 字级别可以只需要根据操作划分类型即可
cor
=
Correction
(
"S"
,
tgt_span
,
(
edit
[
1
],
edit
[
2
]))
results
.
append
(
cor
)
if
verbose
:
print
(
"========== Corrections =========="
)
for
cor
in
results
:
print
(
"Type: {:s}, Position: {:d} -> {:d}, Target: {:s}"
.
format
(
cor
.
op
,
cor
.
inds
[
0
],
cor
.
inds
[
1
],
cor
.
toks
))
return
results
# print(pinyin("朝", style=Style.NORMAL))
opencompass/datasets/lawbench/utils/modules/merger.py
0 → 100644
View file @
861942ab
from
itertools
import
groupby
from
string
import
punctuation
from
typing
import
List
from
modules.tokenizer
import
Tokenizer
from
modules.alignment
import
Alignment
,
read_cilin
,
read_confusion
import
Levenshtein
class
Merger
:
"""
合并编辑操作,从Token-Level转换为Span-Level
"""
def
__init__
(
self
,
granularity
:
str
=
"word"
,
merge
:
bool
=
False
):
chinese_punct
=
"!?。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟–—‘'‛“”„‟…‧."
self
.
punctuation
=
punctuation
+
chinese_punct
self
.
not_merge_token
=
[
punct
for
punct
in
self
.
punctuation
]
self
.
granularity
=
granularity
self
.
merge
=
merge
@
staticmethod
def
_merge_edits
(
seq
,
tag
=
"X"
):
if
seq
:
return
[(
tag
,
seq
[
0
][
1
],
seq
[
-
1
][
2
],
seq
[
0
][
3
],
seq
[
-
1
][
4
])]
else
:
return
seq
@
staticmethod
def
_check_revolve
(
span_a
,
span_b
):
span_a
=
span_a
+
span_a
return
span_b
in
span_a
def
_process_seq
(
self
,
seq
,
src_tokens
,
tgt_tokens
):
if
len
(
seq
)
<=
1
:
return
seq
ops
=
[
op
[
0
]
for
op
in
seq
]
if
set
(
ops
)
==
{
"D"
}
or
set
(
ops
)
==
{
"I"
}:
return
self
.
_merge_edits
(
seq
,
set
(
ops
).
pop
())
if
set
(
ops
)
==
{
"D"
,
"I"
}
or
set
(
ops
)
==
{
"I"
,
"D"
}:
# do not merge this pattern_from_qua.txt
return
seq
if
set
(
ops
)
==
{
"S"
}:
if
self
.
granularity
==
"word"
:
return
seq
else
:
return
self
.
_merge_edits
(
seq
,
"S"
)
if
set
(
ops
)
==
{
"M"
}:
return
self
.
_merge_edits
(
seq
,
"M"
)
return
self
.
_merge_edits
(
seq
,
"S"
)
def
__call__
(
self
,
align_obj
,
src
:
List
,
tgt
:
List
,
verbose
:
bool
=
False
):
"""
Based on ERRANT's merge, adapted for Chinese
"""
src_tokens
=
[
x
[
0
]
for
x
in
src
]
tgt_tokens
=
[
x
[
0
]
for
x
in
tgt
]
edits
=
[]
# Split alignment into groups of M, T and rest. (T has a number after it)
# Todo 一旦插入、删除、替换的对象中含有标点,那么不与其它编辑合并
# Todo 缺失成分标签也不与其它编辑合并
for
op
,
group
in
groupby
(
align_obj
,
lambda
x
:
x
[
0
][
0
]
if
x
[
0
][
0
]
in
{
"M"
,
"T"
}
else
False
,
):
group
=
list
(
group
)
# T is always split TODO: Evaluate this
if
op
==
"T"
:
for
seq
in
group
:
edits
.
append
(
seq
)
# Process D, I and S subsequence
else
:
# Turn the processed sequence into edits
processed
=
self
.
_process_seq
(
group
,
src_tokens
,
tgt_tokens
)
for
seq
in
processed
:
edits
.
append
(
seq
)
filtered_edits
=
[]
i
=
0
while
i
<
len
(
edits
):
e1
=
edits
[
i
][
0
][
0
]
if
i
<
len
(
edits
)
-
2
:
e2
=
edits
[
i
+
1
][
0
][
0
]
e3
=
edits
[
i
+
2
][
0
][
0
]
# Find "S M S" patterns
# Ex:
# S M S
# 冬阴功 对 外国人
# 外国人 对 冬阴功
if
e1
==
"S"
and
e2
==
"M"
and
e3
==
"S"
:
w1
=
""
.
join
(
src_tokens
[
edits
[
i
][
1
]:
edits
[
i
][
2
]])
w2
=
""
.
join
(
tgt_tokens
[
edits
[
i
][
3
]:
edits
[
i
][
4
]])
w3
=
""
.
join
(
src_tokens
[
edits
[
i
+
2
][
1
]:
edits
[
i
+
2
][
2
]])
w4
=
""
.
join
(
tgt_tokens
[
edits
[
i
+
2
][
3
]:
edits
[
i
+
2
][
4
]])
if
min
([
len
(
w1
),
len
(
w2
),
len
(
w3
),
len
(
w4
)])
==
1
:
if
w1
==
w4
and
w2
==
w3
:
group
=
[
edits
[
i
],
edits
[
i
+
1
],
edits
[
i
+
2
]]
processed
=
self
.
_merge_edits
(
group
,
"T"
+
str
(
edits
[
i
+
2
][
2
]
-
edits
[
i
][
1
]))
for
seq
in
processed
:
filtered_edits
.
append
(
seq
)
i
+=
3
else
:
filtered_edits
.
append
(
edits
[
i
])
i
+=
1
else
:
if
Levenshtein
.
distance
(
w1
,
w4
)
<=
1
and
Levenshtein
.
distance
(
w2
,
w3
)
<=
1
:
group
=
[
edits
[
i
],
edits
[
i
+
1
],
edits
[
i
+
2
]]
processed
=
self
.
_merge_edits
(
group
,
"T"
+
str
(
edits
[
i
+
2
][
2
]
-
edits
[
i
][
1
]))
for
seq
in
processed
:
filtered_edits
.
append
(
seq
)
i
+=
3
else
:
filtered_edits
.
append
(
edits
[
i
])
i
+=
1
# Find "D M I" or "I M D" patterns
# Ex:
# D M I
# 旅游 去 陌生 的 地方
# 去 陌生 的 地方 旅游
elif
(
e1
==
"D"
and
(
e2
==
"M"
or
e2
.
startswith
(
"T"
))
and
e3
==
"I"
)
or
(
e1
==
"I"
and
(
e2
==
"M"
or
e2
.
startswith
(
"T"
))
and
e3
==
"D"
):
if
e1
==
"D"
:
delete_token
=
src_tokens
[
edits
[
i
][
1
]:
edits
[
i
][
2
]]
insert_token
=
tgt_tokens
[
edits
[
i
+
2
][
3
]:
edits
[
i
+
2
][
4
]]
else
:
delete_token
=
src_tokens
[
edits
[
i
+
2
][
1
]:
edits
[
i
+
2
][
2
]]
insert_token
=
tgt_tokens
[
edits
[
i
][
3
]:
edits
[
i
][
4
]]
a
,
b
=
""
.
join
(
delete_token
),
""
.
join
(
insert_token
)
if
len
(
a
)
<
len
(
b
):
a
,
b
=
b
,
a
if
a
not
in
self
.
punctuation
and
b
not
in
self
.
punctuation
and
len
(
a
)
-
len
(
b
)
<=
1
:
if
len
(
b
)
==
1
:
if
a
==
b
:
group
=
[
edits
[
i
],
edits
[
i
+
1
],
edits
[
i
+
2
]]
processed
=
self
.
_merge_edits
(
group
,
"T"
+
str
(
edits
[
i
+
2
][
2
]
-
edits
[
i
][
1
]))
for
seq
in
processed
:
filtered_edits
.
append
(
seq
)
i
+=
3
else
:
filtered_edits
.
append
(
edits
[
i
])
i
+=
1
else
:
if
Levenshtein
.
distance
(
a
,
b
)
<=
1
or
(
len
(
a
)
==
len
(
b
)
and
self
.
_check_revolve
(
a
,
b
)):
group
=
[
edits
[
i
],
edits
[
i
+
1
],
edits
[
i
+
2
]]
processed
=
self
.
_merge_edits
(
group
,
"T"
+
str
(
edits
[
i
+
2
][
2
]
-
edits
[
i
][
1
]))
for
seq
in
processed
:
filtered_edits
.
append
(
seq
)
i
+=
3
else
:
filtered_edits
.
append
(
edits
[
i
])
i
+=
1
else
:
filtered_edits
.
append
(
edits
[
i
])
i
+=
1
else
:
if
e1
!=
"M"
:
filtered_edits
.
append
(
edits
[
i
])
i
+=
1
else
:
if
e1
!=
"M"
:
filtered_edits
.
append
(
edits
[
i
])
i
+=
1
# In rare cases with word-level tokenization, the following error can occur:
# M D S M
# 有 時 住 上層
# 有 時住 上層
# Which results in S: 時住 --> 時住
# We need to filter this case out
second_filter
=
[]
for
edit
in
filtered_edits
:
# 避免因为分词错误导致的mismatch现象
span1
=
""
.
join
(
src_tokens
[
edit
[
1
]
:
edit
[
2
]])
span2
=
""
.
join
(
tgt_tokens
[
edit
[
3
]
:
edit
[
4
]])
if
span1
!=
span2
:
if
edit
[
0
]
==
"S"
:
b
=
True
# In rare cases with word-level tokenization, the following error can occur:
# S I I M
# 负责任 老师
# 负 责任 的 老师
# Which results in S: 负责任 --> 负 责任 的
# We need to convert this edit to I: --> 的
# 首部有重叠
common_str
=
""
tmp_new_start_1
=
edit
[
1
]
for
i
in
range
(
edit
[
1
],
edit
[
2
]):
if
not
span2
.
startswith
(
common_str
+
src_tokens
[
i
]):
break
common_str
+=
src_tokens
[
i
]
tmp_new_start_1
=
i
+
1
new_start_1
,
new_start_2
=
edit
[
1
],
edit
[
3
]
if
common_str
:
tmp_str
=
""
for
i
in
range
(
edit
[
3
],
edit
[
4
]):
tmp_str
+=
tgt_tokens
[
i
]
if
tmp_str
==
common_str
:
new_start_1
,
new_start_2
=
tmp_new_start_1
,
i
+
1
# second_filter.append(("S", new_start_1, edit[2], i + 1, edit[4]))
b
=
False
break
elif
len
(
tmp_str
)
>
len
(
common_str
):
break
# 尾部有重叠
common_str
=
""
new_end_1
,
new_end_2
=
edit
[
2
],
edit
[
4
]
tmp_new_end_1
=
edit
[
2
]
for
i
in
reversed
(
range
(
new_start_1
,
edit
[
2
])):
if
not
span2
.
endswith
(
src_tokens
[
i
]
+
common_str
):
break
common_str
=
src_tokens
[
i
]
+
common_str
tmp_new_end_1
=
i
if
common_str
:
tmp_str
=
""
for
i
in
reversed
(
range
(
new_start_2
,
edit
[
4
])):
tmp_str
=
tgt_tokens
[
i
]
+
tmp_str
if
tmp_str
==
common_str
:
new_end_1
,
new_end_2
=
tmp_new_end_1
,
i
b
=
False
break
elif
len
(
tmp_str
)
>
len
(
common_str
):
break
if
b
:
second_filter
.
append
(
edit
)
else
:
if
new_start_1
==
new_end_1
:
new_edit
=
(
"I"
,
new_start_1
,
new_end_1
,
new_start_2
,
new_end_2
)
elif
new_start_2
==
new_end_2
:
new_edit
=
(
"D"
,
new_start_1
,
new_end_1
,
new_start_2
,
new_end_2
)
else
:
new_edit
=
(
"S"
,
new_start_1
,
new_end_1
,
new_start_2
,
new_end_2
)
second_filter
.
append
(
new_edit
)
else
:
second_filter
.
append
(
edit
)
if
verbose
:
print
(
"========== Parallels =========="
)
print
(
""
.
join
(
src_tokens
))
print
(
""
.
join
(
tgt_tokens
))
print
(
"========== Results =========="
)
for
edit
in
second_filter
:
op
=
edit
[
0
]
s
=
" "
.
join
(
src_tokens
[
edit
[
1
]:
edit
[
2
]])
t
=
" "
.
join
(
tgt_tokens
[
edit
[
3
]:
edit
[
4
]])
print
(
f
"
{
op
}
:
\t
{
s
}
\t
-->
\t
{
t
}
"
)
print
(
"========== Infos =========="
)
print
(
str
(
src
))
print
(
str
(
tgt
))
return
second_filter
if
__name__
==
"__main__"
:
tokenizer
=
Tokenizer
(
"char"
)
semantic_dict
,
semantic_class
=
read_cilin
()
confusion_dict
=
read_confusion
()
alignment
=
Alignment
(
semantic_dict
,
confusion_dict
)
sents
=
[
"所 以 印 度 对 全 世 界 人 没 有 说 服 不 要 吃 牛 肉 。"
.
replace
(
" "
,
""
),
"所 以 印 度 没 有 说 服 全 世 界 人 不 要 吃 牛 肉 。"
.
replace
(
" "
,
""
)]
src
,
tgt
=
tokenizer
(
sents
)
align_obj
=
alignment
(
src
,
tgt
)
m
=
Merger
()
m
(
align_obj
,
src
,
tgt
,
verbose
=
True
)
\ No newline at end of file
opencompass/datasets/lawbench/utils/modules/tokenization.py
0 → 100644
View file @
861942ab
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
collections
import
unicodedata
import
six
def
convert_to_unicode
(
text
):
"""Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
if
six
.
PY3
:
if
isinstance
(
text
,
str
):
return
text
elif
isinstance
(
text
,
bytes
):
return
text
.
decode
(
"utf-8"
,
"ignore"
)
else
:
raise
ValueError
(
"Unsupported string type: %s"
%
(
type
(
text
)))
elif
six
.
PY2
:
if
isinstance
(
text
,
str
):
return
text
.
decode
(
"utf-8"
,
"ignore"
)
elif
isinstance
(
text
,
unicode
):
return
text
else
:
raise
ValueError
(
"Unsupported string type: %s"
%
(
type
(
text
)))
else
:
raise
ValueError
(
"Not running on Python2 or Python 3?"
)
def
printable_text
(
text
):
"""Returns text encoded in a way suitable for print or `tf.logging`."""
# These functions want `str` for both Python2 and Python3, but in one case
# it's a Unicode string and in the other it's a byte string.
if
six
.
PY3
:
if
isinstance
(
text
,
str
):
return
text
elif
isinstance
(
text
,
bytes
):
return
text
.
decode
(
"utf-8"
,
"ignore"
)
else
:
raise
ValueError
(
"Unsupported string type: %s"
%
(
type
(
text
)))
elif
six
.
PY2
:
if
isinstance
(
text
,
str
):
return
text
elif
isinstance
(
text
,
unicode
):
return
text
.
encode
(
"utf-8"
)
else
:
raise
ValueError
(
"Unsupported string type: %s"
%
(
type
(
text
)))
else
:
raise
ValueError
(
"Not running on Python2 or Python 3?"
)
def
load_vocab
(
vocab_file
):
"""Loads a vocabulary file into a dictionary."""
vocab
=
collections
.
OrderedDict
()
index
=
0
with
open
(
vocab_file
,
"r"
)
as
reader
:
while
True
:
token
=
convert_to_unicode
(
reader
.
readline
())
if
not
token
:
break
token
=
token
.
strip
()
vocab
[
token
]
=
index
index
+=
1
return
vocab
def
convert_by_vocab
(
vocab
,
items
):
"""Converts a sequence of [tokens|ids] using the vocab."""
output
=
[]
for
item
in
items
:
if
item
not
in
vocab
:
print
(
"warning: %s not in vocab"
%
item
)
item
=
"[UNK]"
output
.
append
(
vocab
[
item
])
return
output
def
convert_tokens_to_ids
(
vocab
,
tokens
):
return
convert_by_vocab
(
vocab
,
tokens
)
def
convert_ids_to_tokens
(
inv_vocab
,
ids
):
return
convert_by_vocab
(
inv_vocab
,
ids
)
def
whitespace_tokenize
(
text
):
"""Runs basic whitespace cleaning and splitting on a peice of text."""
text
=
text
.
strip
()
if
not
text
:
return
[]
tokens
=
text
.
split
()
return
tokens
class
FullTokenizer
(
object
):
"""Runs end-to-end tokenziation."""
def
__init__
(
self
,
vocab_file
,
do_lower_case
=
True
):
self
.
vocab
=
load_vocab
(
vocab_file
)
self
.
inv_vocab
=
{
v
:
k
for
k
,
v
in
self
.
vocab
.
items
()}
self
.
basic_tokenizer
=
BasicTokenizer
(
do_lower_case
=
do_lower_case
)
self
.
wordpiece_tokenizer
=
WordpieceTokenizer
(
vocab
=
self
.
vocab
)
def
tokenize
(
self
,
text
):
split_tokens
=
[]
for
token
in
self
.
basic_tokenizer
.
tokenize
(
text
):
for
sub_token
in
self
.
wordpiece_tokenizer
.
tokenize
(
token
):
split_tokens
.
append
(
sub_token
)
return
split_tokens
def
convert_tokens_to_ids
(
self
,
tokens
):
return
convert_by_vocab
(
self
.
vocab
,
tokens
)
def
convert_ids_to_tokens
(
self
,
ids
):
return
convert_by_vocab
(
self
.
inv_vocab
,
ids
)
class
BasicTokenizer
(
object
):
"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
def
__init__
(
self
,
do_lower_case
=
True
):
"""Constructs a BasicTokenizer.
Args:
do_lower_case: Whether to lower case the input.
"""
self
.
do_lower_case
=
do_lower_case
def
tokenize
(
self
,
text
):
"""Tokenizes a piece of text."""
text
=
convert_to_unicode
(
text
)
text
=
self
.
_clean_text
(
text
)
# This was added on November 1st, 2018 for the multilingual and Chinese
# models. This is also applied to the English models now, but it doesn't
# matter since the English models were not trained on any Chinese data
# and generally don't have any Chinese data in them (there are Chinese
# characters in the vocabulary because Wikipedia does have some Chinese
# words in the English Wikipedia.).
text
=
self
.
_tokenize_chinese_chars
(
text
)
orig_tokens
=
whitespace_tokenize
(
text
)
split_tokens
=
[]
for
token
in
orig_tokens
:
if
self
.
do_lower_case
:
token
=
token
.
lower
()
token
=
self
.
_run_strip_accents
(
token
)
split_tokens
.
extend
(
self
.
_run_split_on_punc
(
token
))
output_tokens
=
whitespace_tokenize
(
" "
.
join
(
split_tokens
))
return
output_tokens
def
_run_strip_accents
(
self
,
text
):
"""Strips accents from a piece of text."""
text
=
unicodedata
.
normalize
(
"NFD"
,
text
)
output
=
[]
for
char
in
text
:
cat
=
unicodedata
.
category
(
char
)
if
cat
==
"Mn"
:
continue
output
.
append
(
char
)
return
""
.
join
(
output
)
def
_run_split_on_punc
(
self
,
text
):
"""Splits punctuation on a piece of text."""
chars
=
list
(
text
)
i
=
0
start_new_word
=
True
output
=
[]
while
i
<
len
(
chars
):
char
=
chars
[
i
]
if
_is_punctuation
(
char
):
output
.
append
([
char
])
start_new_word
=
True
else
:
if
start_new_word
:
output
.
append
([])
start_new_word
=
False
output
[
-
1
].
append
(
char
)
i
+=
1
return
[
""
.
join
(
x
)
for
x
in
output
]
def
_tokenize_chinese_chars
(
self
,
text
):
"""Adds whitespace around any CJK character."""
output
=
[]
for
char
in
text
:
cp
=
ord
(
char
)
if
self
.
_is_chinese_char
(
cp
):
output
.
append
(
" "
)
output
.
append
(
char
)
output
.
append
(
" "
)
else
:
output
.
append
(
char
)
return
""
.
join
(
output
)
def
_is_chinese_char
(
self
,
cp
):
"""Checks whether CP is the codepoint of a CJK character."""
# This defines a "chinese character" as anything in the CJK Unicode block:
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
#
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
# despite its name. The modern Korean Hangul alphabet is a different block,
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
# space-separated words, so they are not treated specially and handled
# like the all of the other languages.
if
((
cp
>=
0x4E00
and
cp
<=
0x9FFF
)
or
#
(
cp
>=
0x3400
and
cp
<=
0x4DBF
)
or
#
(
cp
>=
0x20000
and
cp
<=
0x2A6DF
)
or
#
(
cp
>=
0x2A700
and
cp
<=
0x2B73F
)
or
#
(
cp
>=
0x2B740
and
cp
<=
0x2B81F
)
or
#
(
cp
>=
0x2B820
and
cp
<=
0x2CEAF
)
or
(
cp
>=
0xF900
and
cp
<=
0xFAFF
)
or
#
(
cp
>=
0x2F800
and
cp
<=
0x2FA1F
)):
#
return
True
return
False
def
_clean_text
(
self
,
text
):
"""Performs invalid character removal and whitespace cleanup on text."""
output
=
[]
for
char
in
text
:
cp
=
ord
(
char
)
if
cp
==
0
or
cp
==
0xfffd
or
_is_control
(
char
):
continue
if
_is_whitespace
(
char
):
output
.
append
(
" "
)
else
:
output
.
append
(
char
)
return
""
.
join
(
output
)
class
WordpieceTokenizer
(
object
):
"""Runs WordPiece tokenziation."""
def
__init__
(
self
,
vocab
,
unk_token
=
"[UNK]"
,
max_input_chars_per_word
=
100
):
self
.
vocab
=
vocab
self
.
unk_token
=
unk_token
self
.
max_input_chars_per_word
=
max_input_chars_per_word
def
tokenize
(
self
,
text
):
"""Tokenizes a piece of text into its word pieces.
This uses a greedy longest-match-first algorithm to perform tokenization
using the given vocabulary.
For example:
input = "unaffable"
output = ["un", "##aff", "##able"]
Args:
text: A single token or whitespace separated tokens. This should have
already been passed through `BasicTokenizer.
Returns:
A list of wordpiece tokens.
"""
text
=
convert_to_unicode
(
text
)
output_tokens
=
[]
for
token
in
whitespace_tokenize
(
text
):
chars
=
list
(
token
)
if
len
(
chars
)
>
self
.
max_input_chars_per_word
:
output_tokens
.
append
(
self
.
unk_token
)
continue
is_bad
=
False
start
=
0
sub_tokens
=
[]
while
start
<
len
(
chars
):
end
=
len
(
chars
)
cur_substr
=
None
while
start
<
end
:
substr
=
""
.
join
(
chars
[
start
:
end
])
if
start
>
0
:
substr
=
"##"
+
substr
if
substr
in
self
.
vocab
:
cur_substr
=
substr
break
end
-=
1
if
cur_substr
is
None
:
is_bad
=
True
break
sub_tokens
.
append
(
cur_substr
)
start
=
end
if
is_bad
:
# output_tokens.append(self.unk_token)
output_tokens
.
append
(
token
)
# keep the UNK token
else
:
output_tokens
.
extend
(
sub_tokens
)
return
output_tokens
def
_is_whitespace
(
char
):
"""Checks whether `chars` is a whitespace character."""
# \t, \n, and \r are technically contorl characters but we treat them
# as whitespace since they are generally considered as such.
if
char
==
" "
or
char
==
"
\t
"
or
char
==
"
\n
"
or
char
==
"
\r
"
:
return
True
cat
=
unicodedata
.
category
(
char
)
if
cat
==
"Zs"
:
return
True
return
False
def
_is_control
(
char
):
"""Checks whether `chars` is a control character."""
# These are technically control characters but we count them as whitespace
# characters.
if
char
==
"
\t
"
or
char
==
"
\n
"
or
char
==
"
\r
"
:
return
False
cat
=
unicodedata
.
category
(
char
)
if
cat
.
startswith
(
"C"
):
return
True
return
False
def
_is_punctuation
(
char
):
"""Checks whether `chars` is a punctuation character."""
cp
=
ord
(
char
)
# We treat all non-letter/number ASCII as punctuation.
# Characters such as "^", "$", and "`" are not in the Unicode
# Punctuation class but we treat them as punctuation anyways, for
# consistency.
if
((
cp
>=
33
and
cp
<=
47
)
or
(
cp
>=
58
and
cp
<=
64
)
or
(
cp
>=
91
and
cp
<=
96
)
or
(
cp
>=
123
and
cp
<=
126
)):
return
True
cat
=
unicodedata
.
category
(
char
)
if
cat
.
startswith
(
"P"
):
return
True
return
False
\ No newline at end of file
opencompass/datasets/lawbench/utils/modules/tokenizer.py
0 → 100644
View file @
861942ab
from
ltp
import
LTP
from
typing
import
List
from
pypinyin
import
pinyin
,
Style
,
lazy_pinyin
import
torch
import
os
import
functools
class
Tokenizer
:
"""
分词器
"""
def
__init__
(
self
,
granularity
:
str
=
"word"
,
device
:
str
=
"cpu"
,
segmented
:
bool
=
False
,
bpe
:
bool
=
False
,
)
->
None
:
"""
构造函数
:param mode: 分词模式,可选级别:字级别(char)、词级别(word)
"""
self
.
ltp
=
None
if
granularity
==
"word"
:
self
.
ltp
=
LTP
(
device
=
torch
.
device
(
device
)
if
torch
.
cuda
.
is_available
()
else
torch
.
device
(
"cpu"
))
self
.
ltp
.
add_words
(
words
=
[
"[缺失成分]"
],
max_window
=
6
)
self
.
segmented
=
segmented
self
.
granularity
=
granularity
if
self
.
granularity
==
"word"
:
self
.
tokenizer
=
self
.
split_word
elif
self
.
granularity
==
"char"
:
self
.
tokenizer
=
functools
.
partial
(
self
.
split_char
,
bpe
=
bpe
)
else
:
raise
NotImplementedError
def
__repr__
(
self
)
->
str
:
return
"{:s}
\n
Mode:{:s}
\n
}"
.
format
(
str
(
self
.
__class__
.
__name__
),
self
.
mode
)
def
__call__
(
self
,
input_strings
:
List
[
str
]
)
->
List
:
"""
分词函数
:param input_strings: 需要分词的字符串列表
:return: 分词后的结果列表,由元组组成,元组为(token,pos_tag,pinyin)的形式
"""
if
not
self
.
segmented
:
input_strings
=
[
""
.
join
(
s
.
split
(
" "
))
for
s
in
input_strings
]
results
=
self
.
tokenizer
(
input_strings
)
return
results
def
split_char
(
self
,
input_strings
:
List
[
str
],
bpe
=
False
)
->
List
:
"""
分字函数
:param input_strings: 需要分字的字符串
:return: 分字结果
"""
if
bpe
:
from
.
import
tokenization
project_dir
=
os
.
path
.
dirname
(
os
.
path
.
dirname
(
__file__
))
tokenizer
=
tokenization
.
FullTokenizer
(
vocab_file
=
os
.
path
.
join
(
project_dir
,
"data"
,
"chinese_vocab.txt"
),
do_lower_case
=
False
)
results
=
[]
for
input_string
in
input_strings
:
if
not
self
.
segmented
:
# 如果没有被分字,就按照每个字符隔开(不考虑英文标点的特殊处理,也不考虑BPE),否则遵循原分字结果
segment_string
=
" "
.
join
([
char
for
char
in
input_string
]
if
not
bpe
else
tokenizer
.
tokenize
(
input_string
))
else
:
segment_string
=
input_string
# print(segment_string)
segment_string
=
segment_string
.
replace
(
"[ 缺 失 成 分 ]"
,
"[缺失成分]"
).
split
(
" "
)
# 缺失成分当成一个单独的token
results
.
append
([(
char
,
"unk"
,
pinyin
(
char
,
style
=
Style
.
NORMAL
,
heteronym
=
True
)[
0
])
for
char
in
segment_string
])
return
results
def
split_word
(
self
,
input_strings
:
List
[
str
])
->
List
:
"""
分词函数
:param input_strings: 需要分词的字符串
:return: 分词结果
"""
if
self
.
segmented
:
seg
,
hidden
=
self
.
ltp
.
seg
([
input_string
.
split
(
" "
)
for
input_string
in
input_strings
],
is_preseged
=
True
)
else
:
seg
,
hidden
=
self
.
ltp
.
seg
(
input_strings
)
pos
=
self
.
ltp
.
pos
(
hidden
)
result
=
[]
for
s
,
p
in
zip
(
seg
,
pos
):
pinyin
=
[
lazy_pinyin
(
word
)
for
word
in
s
]
result
.
append
(
list
(
zip
(
s
,
p
,
pinyin
)))
return
result
if
__name__
==
"__main__"
:
tokenizer
=
Tokenizer
(
"word"
)
print
(
tokenizer
([
"LAC是个优秀的分词工具"
,
"百度是一家高科技公司"
]))
opencompass/datasets/lawbench/utils/parallel_to_m2.py
0 → 100644
View file @
861942ab
import
os
from
modules.annotator
import
Annotator
from
modules.tokenizer
import
Tokenizer
import
argparse
from
collections
import
Counter
from
tqdm
import
tqdm
import
torch
from
collections
import
defaultdict
from
multiprocessing
import
Pool
from
opencc
import
OpenCC
import
timeout_decorator
os
.
environ
[
"TOKENIZERS_PARALLELISM"
]
=
"false"
annotator
,
sentence_to_tokenized
=
None
,
None
cc
=
OpenCC
(
"t2s"
)
@
timeout_decorator
.
timeout
(
10
)
def
annotate_with_time_out
(
line
):
"""
:param line:
:return:
"""
sent_list
=
line
.
split
(
"
\t
"
)[
1
:]
source
=
sent_list
[
0
]
if
args
.
segmented
:
source
=
source
.
strip
()
else
:
source
=
""
.
join
(
source
.
strip
().
split
())
output_str
=
""
for
idx
,
target
in
enumerate
(
sent_list
[
1
:]):
try
:
if
args
.
segmented
:
target
=
target
.
strip
()
else
:
target
=
""
.
join
(
target
.
strip
().
split
())
if
not
args
.
no_simplified
:
target
=
cc
.
convert
(
target
)
source_tokenized
,
target_tokenized
=
sentence_to_tokenized
[
source
],
sentence_to_tokenized
[
target
]
out
,
cors
=
annotator
(
source_tokenized
,
target_tokenized
,
idx
)
if
idx
==
0
:
output_str
+=
""
.
join
(
out
[:
-
1
])
else
:
output_str
+=
""
.
join
(
out
[
1
:
-
1
])
except
Exception
:
raise
Exception
return
output_str
def
annotate
(
line
):
"""
:param line:
:return:
"""
sent_list
=
line
.
split
(
"
\t
"
)[
1
:]
source
=
sent_list
[
0
]
if
args
.
segmented
:
source
=
source
.
strip
()
else
:
source
=
""
.
join
(
source
.
strip
().
split
())
output_str
=
""
for
idx
,
target
in
enumerate
(
sent_list
[
1
:]):
try
:
if
args
.
segmented
:
target
=
target
.
strip
()
else
:
target
=
""
.
join
(
target
.
strip
().
split
())
if
not
args
.
no_simplified
:
target
=
cc
.
convert
(
target
)
source_tokenized
,
target_tokenized
=
sentence_to_tokenized
[
source
],
sentence_to_tokenized
[
target
]
out
,
cors
=
annotator
(
source_tokenized
,
target_tokenized
,
idx
)
if
idx
==
0
:
output_str
+=
""
.
join
(
out
[:
-
1
])
else
:
output_str
+=
""
.
join
(
out
[
1
:
-
1
])
except
Exception
:
raise
Exception
return
output_str
def
firsttime_process
(
args
):
tokenizer
=
Tokenizer
(
args
.
granularity
,
args
.
device
,
args
.
segmented
,
args
.
bpe
)
global
annotator
,
sentence_to_tokenized
annotator
=
Annotator
.
create_default
(
args
.
granularity
,
args
.
multi_cheapest_strategy
)
lines
=
open
(
args
.
file
,
"r"
,
encoding
=
"utf-8"
).
read
().
strip
().
split
(
"
\n
"
)
# format: id src tgt1 tgt2...
# error_types = []
with
open
(
args
.
output
,
"w"
,
encoding
=
"utf-8"
)
as
f
:
count
=
0
sentence_set
=
set
()
sentence_to_tokenized
=
{}
for
line
in
lines
:
sent_list
=
line
.
split
(
"
\t
"
)[
1
:]
for
idx
,
sent
in
enumerate
(
sent_list
):
if
args
.
segmented
:
# print(sent)
sent
=
sent
.
strip
()
else
:
sent
=
""
.
join
(
sent
.
split
()).
strip
()
if
idx
>=
1
:
if
not
args
.
no_simplified
:
sentence_set
.
add
(
cc
.
convert
(
sent
))
else
:
sentence_set
.
add
(
sent
)
else
:
sentence_set
.
add
(
sent
)
batch
=
[]
for
sent
in
tqdm
(
sentence_set
):
count
+=
1
if
sent
:
batch
.
append
(
sent
)
if
count
%
args
.
batch_size
==
0
:
results
=
tokenizer
(
batch
)
for
s
,
r
in
zip
(
batch
,
results
):
sentence_to_tokenized
[
s
]
=
r
# Get tokenization map.
batch
=
[]
if
batch
:
results
=
tokenizer
(
batch
)
for
s
,
r
in
zip
(
batch
,
results
):
sentence_to_tokenized
[
s
]
=
r
# Get tokenization map.
timeout_indices
=
[]
# 单进程模式
for
idx
,
line
in
enumerate
(
tqdm
(
lines
)):
try
:
ret
=
annotate_with_time_out
(
line
)
except
Exception
:
timeout_indices
.
append
(
idx
)
return
timeout_indices
def
main
(
args
):
timeout_indices
=
firsttime_process
(
args
)
tokenizer
=
Tokenizer
(
args
.
granularity
,
args
.
device
,
args
.
segmented
,
args
.
bpe
)
global
annotator
,
sentence_to_tokenized
annotator
=
Annotator
.
create_default
(
args
.
granularity
,
args
.
multi_cheapest_strategy
)
lines
=
open
(
args
.
file
,
"r"
,
encoding
=
"utf-8"
).
read
().
strip
().
split
(
"
\n
"
)
new_lines
=
[]
# format: id src tgt1 tgt2...
with
open
(
args
.
output
,
"w"
,
encoding
=
"utf-8"
)
as
f
:
count
=
0
sentence_set
=
set
()
sentence_to_tokenized
=
{}
for
line_idx
,
line
in
enumerate
(
lines
):
if
line_idx
in
timeout_indices
:
# print(f"line before split: {line}")
line_split
=
line
.
split
(
"
\t
"
)
line_number
,
sent_list
=
line_split
[
0
],
line_split
[
1
:]
assert
len
(
sent_list
)
==
2
sent_list
[
-
1
]
=
" 无"
line
=
line_number
+
"
\t
"
+
"
\t
"
.
join
(
sent_list
)
# print(f"line time out: {line}")
new_lines
.
append
(
line
)
else
:
new_lines
.
append
(
line
)
sent_list
=
line
.
split
(
"
\t
"
)[
1
:]
for
idx
,
sent
in
enumerate
(
sent_list
):
if
args
.
segmented
:
# print(sent)
sent
=
sent
.
strip
()
else
:
sent
=
""
.
join
(
sent
.
split
()).
strip
()
if
idx
>=
1
:
if
not
args
.
no_simplified
:
sentence_set
.
add
(
cc
.
convert
(
sent
))
else
:
sentence_set
.
add
(
sent
)
else
:
sentence_set
.
add
(
sent
)
batch
=
[]
for
sent
in
tqdm
(
sentence_set
):
count
+=
1
if
sent
:
batch
.
append
(
sent
)
if
count
%
args
.
batch_size
==
0
:
results
=
tokenizer
(
batch
)
for
s
,
r
in
zip
(
batch
,
results
):
sentence_to_tokenized
[
s
]
=
r
# Get tokenization map.
batch
=
[]
if
batch
:
results
=
tokenizer
(
batch
)
for
s
,
r
in
zip
(
batch
,
results
):
sentence_to_tokenized
[
s
]
=
r
# Get tokenization map.
# 单进程模式
lines
=
new_lines
for
idx
,
line
in
enumerate
(
tqdm
(
lines
)):
ret
=
annotate
(
line
)
f
.
write
(
ret
)
f
.
write
(
"
\n
"
)
# 多进程模式:仅在Linux环境下测试,建议在linux服务器上使用
# with Pool(args.worker_num) as pool:
# for ret in pool.imap(annotate, tqdm(lines), chunksize=8):
# if ret:
# f.write(ret)
# f.write("\n")
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
(
description
=
"Choose input file to annotate"
)
parser
.
add_argument
(
"-f"
,
"--file"
,
type
=
str
,
required
=
True
,
help
=
"Input parallel file"
)
parser
.
add_argument
(
"-o"
,
"--output"
,
type
=
str
,
help
=
"Output file"
,
required
=
True
)
parser
.
add_argument
(
"-b"
,
"--batch_size"
,
type
=
int
,
help
=
"The size of batch"
,
default
=
128
)
parser
.
add_argument
(
"-d"
,
"--device"
,
type
=
int
,
help
=
"The ID of GPU"
,
default
=
0
)
parser
.
add_argument
(
"-w"
,
"--worker_num"
,
type
=
int
,
help
=
"The number of workers"
,
default
=
16
)
parser
.
add_argument
(
"-g"
,
"--granularity"
,
type
=
str
,
help
=
"Choose char-level or word-level evaluation"
,
default
=
"char"
)
parser
.
add_argument
(
"-m"
,
"--merge"
,
help
=
"Whether merge continuous replacement/deletion/insertion"
,
action
=
"store_true"
)
parser
.
add_argument
(
"-s"
,
"--multi_cheapest_strategy"
,
type
=
str
,
choices
=
[
"first"
,
"all"
],
default
=
"all"
)
parser
.
add_argument
(
"--segmented"
,
help
=
"Whether tokens have been segmented"
,
action
=
"store_true"
)
# 支持提前token化,用空格隔开
parser
.
add_argument
(
"--no_simplified"
,
help
=
"Whether simplifying chinese"
,
action
=
"store_true"
)
# 将所有corrections转换为简体中文
parser
.
add_argument
(
"--bpe"
,
help
=
"Whether to use bpe"
,
action
=
"store_true"
)
# 支持 bpe 切分英文单词
args
=
parser
.
parse_args
()
main
(
args
)
opencompass/datasets/lawbench/utils/rc_f1.py
0 → 100644
View file @
861942ab
"""Official evaluation script for CAIL-2021.
The code is based partially on CoQA evaluation script.
"""
import
json
import
sys
from
collections
import
Counter
class
CJRCEvaluator
:
def
__init__
(
self
,
gold_file
):
self
.
gold_data
=
CJRCEvaluator
.
gold_answers_to_dict
(
gold_file
)
@
staticmethod
def
gold_answers_to_dict
(
gold_file
):
dataset
=
json
.
load
(
open
(
gold_file
,
mode
=
"r"
,
encoding
=
"utf-8"
))
gold_dict
=
{}
# id_to_domain = {}
for
story
in
dataset
[
'data'
]:
qas
=
story
[
"paragraphs"
][
0
][
"qas"
]
for
qa
in
qas
:
qid
=
qa
[
'id'
]
gold_answers
=
[]
answers
=
qa
[
"answers"
]
if
len
(
answers
)
==
0
:
gold_answers
=
[
''
]
else
:
for
answer
in
qa
[
"answers"
]:
if
type
(
answer
)
==
dict
:
gold_answers
.
append
(
answer
[
"text"
])
elif
type
(
answer
)
==
list
:
gold_answers
.
append
(
""
.
join
([
a
[
"text"
]
for
a
in
answer
]))
if
qid
in
gold_dict
:
sys
.
stderr
.
write
(
"Gold file has duplicate stories: {}"
.
format
(
qid
))
gold_dict
[
qid
]
=
gold_answers
return
gold_dict
@
staticmethod
def
preds_to_dict
(
pred_file
):
preds
=
json
.
load
(
open
(
pred_file
,
mode
=
"r"
,
encoding
=
"utf-8"
))
pred_dict
=
{}
for
pred
in
preds
:
pred_dict
[
pred
[
'id'
]]
=
""
.
join
(
pred
[
'answer'
])
return
pred_dict
@
staticmethod
def
normalize_answer
(
s
):
"""Lower text and remove punctuation, storys and extra whitespace."""
def
remove_punc
(
text
):
return
""
.
join
(
ch
for
ch
in
text
if
ch
.
isdigit
()
or
ch
.
isalpha
())
def
lower
(
text
):
return
text
.
lower
()
return
remove_punc
(
lower
(
s
))
@
staticmethod
def
get_tokens
(
s
):
if
not
s
:
return
[]
return
list
(
CJRCEvaluator
.
normalize_answer
(
s
))
@
staticmethod
def
compute_exact
(
a_gold
,
a_pred
):
return
int
(
CJRCEvaluator
.
normalize_answer
(
a_gold
)
==
CJRCEvaluator
.
normalize_answer
(
a_pred
))
@
staticmethod
def
compute_f1
(
a_gold
,
a_pred
):
gold_toks
=
CJRCEvaluator
.
get_tokens
(
a_gold
)
pred_toks
=
CJRCEvaluator
.
get_tokens
(
a_pred
)
common
=
Counter
(
gold_toks
)
&
Counter
(
pred_toks
)
num_same
=
sum
(
common
.
values
())
if
len
(
gold_toks
)
==
0
or
len
(
pred_toks
)
==
0
:
# If either is no-answer, then F1 is 1 if they agree, 0 otherwise
return
int
(
gold_toks
==
pred_toks
)
if
num_same
==
0
:
return
0
precision
=
1.0
*
num_same
/
len
(
pred_toks
)
recall
=
1.0
*
num_same
/
len
(
gold_toks
)
f1
=
(
2
*
precision
*
recall
)
/
(
precision
+
recall
)
return
f1
@
staticmethod
def
_compute_turn_score
(
a_gold_list
,
a_pred
):
f1_sum
=
0.0
em_sum
=
0.0
if
len
(
a_gold_list
)
>
1
:
for
i
in
range
(
len
(
a_gold_list
)):
# exclude the current answer
gold_answers
=
a_gold_list
[
0
:
i
]
+
a_gold_list
[
i
+
1
:]
em_sum
+=
max
(
CJRCEvaluator
.
compute_exact
(
a
,
a_pred
)
for
a
in
gold_answers
)
f1_sum
+=
max
(
CJRCEvaluator
.
compute_f1
(
a
,
a_pred
)
for
a
in
gold_answers
)
else
:
em_sum
+=
max
(
CJRCEvaluator
.
compute_exact
(
a
,
a_pred
)
for
a
in
a_gold_list
)
f1_sum
+=
max
(
CJRCEvaluator
.
compute_f1
(
a
,
a_pred
)
for
a
in
a_gold_list
)
if
f1_sum
!=
1
:
a
=
1
+
1
return
{
'em'
:
em_sum
/
max
(
1
,
len
(
a_gold_list
)),
'f1'
:
f1_sum
/
max
(
1
,
len
(
a_gold_list
))}
def
compute_turn_score
(
self
,
qid
,
a_pred
):
''' This is the function what you are probably looking for. a_pred is the answer string your model predicted. '''
a_gold_list
=
self
.
gold_data
[
qid
]
return
CJRCEvaluator
.
_compute_turn_score
(
a_gold_list
,
a_pred
)
def
get_raw_scores
(
self
,
pred_data
):
''''Returns a dict with score'''
exact_scores
=
{}
f1_scores
=
{}
for
qid
in
self
.
gold_data
:
if
qid
not
in
pred_data
:
sys
.
stderr
.
write
(
'Missing prediction for {}
\n
'
.
format
(
qid
))
continue
a_pred
=
pred_data
[
qid
]
scores
=
self
.
compute_turn_score
(
qid
,
a_pred
)
# Take max over all gold answers
exact_scores
[
qid
]
=
scores
[
'em'
]
f1_scores
[
qid
]
=
scores
[
'f1'
]
return
exact_scores
,
f1_scores
def
get_raw_scores_human
(
self
):
'''
Returns a dict with score
'''
exact_scores
=
{}
f1_scores
=
{}
for
qid
in
self
.
gold_data
:
f1_sum
=
0.0
em_sum
=
0.0
if
len
(
self
.
gold_data
[
qid
])
>
1
:
for
i
in
range
(
len
(
self
.
gold_data
[
qid
])):
# exclude the current answer
gold_answers
=
self
.
gold_data
[
qid
][
0
:
i
]
+
self
.
gold_data
[
qid
][
i
+
1
:]
em_sum
+=
max
(
CJRCEvaluator
.
compute_exact
(
a
,
self
.
gold_data
[
qid
][
i
])
for
a
in
gold_answers
)
f1_sum
+=
max
(
CJRCEvaluator
.
compute_f1
(
a
,
self
.
gold_data
[
qid
][
i
])
for
a
in
gold_answers
)
else
:
exit
(
"Gold answers should be multiple: {}={}"
.
format
(
qid
,
self
.
gold_data
[
qid
]))
exact_scores
[
qid
]
=
em_sum
/
len
(
self
.
gold_data
[
qid
])
f1_scores
[
qid
]
=
f1_sum
/
len
(
self
.
gold_data
[
qid
])
return
exact_scores
,
f1_scores
def
human_performance
(
self
):
exact_scores
,
f1_scores
=
self
.
get_raw_scores_human
()
return
self
.
get_total_scores
(
exact_scores
,
f1_scores
)
def
model_performance
(
self
,
pred_data
):
exact_scores
,
f1_scores
=
self
.
get_raw_scores
(
pred_data
)
return
self
.
get_total_scores
(
exact_scores
,
f1_scores
)
def
get_total_scores
(
self
,
exact_scores
,
f1_scores
):
em_total
,
f1_total
,
turn_count
=
0
,
0
,
0
scores
=
{}
for
qid
in
self
.
gold_data
:
em_total
+=
exact_scores
.
get
(
qid
,
0
)
f1_total
+=
f1_scores
.
get
(
qid
,
0
)
turn_count
+=
1
scores
[
"F1"
]
=
round
(
f1_total
/
max
(
1
,
turn_count
)
*
100
,
1
)
return
scores
requirements/runtime.txt
View file @
861942ab
absl-py
absl-py
accelerate>=0.19.0
accelerate>=0.19.0
boto3
boto3
cn2an
colossalai
colossalai
cpm_kernels
cpm_kernels
datasets>=2.12.0
datasets>=2.12.0
...
@@ -9,11 +10,15 @@ fairscale
...
@@ -9,11 +10,15 @@ fairscale
faiss_gpu==1.7.2
faiss_gpu==1.7.2
fuzzywuzzy
fuzzywuzzy
jieba
jieba
ltp
mmengine>=0.8.2
mmengine>=0.8.2
nltk==3.8
nltk==3.8
numpy==1.23.4
numpy==1.23.4
openai
openai
OpenCC
pandas<2.0.0
pandas<2.0.0
pypinyin
python-Levenshtein
rank_bm25==0.2.2
rank_bm25==0.2.2
rapidfuzz
rapidfuzz
requests==2.31.0
requests==2.31.0
...
@@ -25,6 +30,7 @@ seaborn
...
@@ -25,6 +30,7 @@ seaborn
sentence_transformers==2.2.2
sentence_transformers==2.2.2
tabulate
tabulate
tiktoken
tiktoken
timeout_decorator
tokenizers>=0.13.3
tokenizers>=0.13.3
torch>=1.13.1
torch>=1.13.1
tqdm==4.64.1
tqdm==4.64.1
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment