Commit ad8941e3 authored by Shuimo

Update the text_badcase script and add an automatic S3 upload function

parent 2dcf477d
{
"accuracy": 1.0,
"precision": 1.0,
"recall": 1.0,
"f1_score": 1.0,
"pdf间的平均编辑距离": 19.82051282051282,
"pdf间的平均bleu": 0.9002485609584511,
"阅读顺序编辑距离": 0.3176895306859206,
"分段准确率": 0.8989169675090253,
"行内公式准确率": {
"accuracy": 0.9782741738066095,
"precision": 0.9782741738066095,
"recall": 1.0,
"f1_score": 0.9890177880897139
},
"行内公式编辑距离": 0.0,
"行内公式bleu": 0.20340450120213166,
"行间公式准确率": {
"accuracy": 1.0,
"precision": 1.0,
"recall": 1.0,
"f1_score": 1.0
},
"行间公式编辑距离": 0.0,
"行间公式bleu": 0.3662262622386575,
"丢弃文本准确率": {
"accuracy": 0.867870036101083,
"precision": 0.9064856711915535,
"recall": 0.9532117367168914,
"f1_score": 0.9292616930807885
},
"丢弃文本标签准确率": {
"color_background_header_txt_block": {
"precision": 0.0,
"recall": 0.0,
"f1-score": 0.0,
"support": 41.0
},
"rotate": {
"precision": 1.0,
"recall": 0.9682539682539683,
"f1-score": 0.9838709677419355,
"support": 63.0
},
"footnote": {
"precision": 1.0,
"recall": 0.883495145631068,
"f1-score": 0.9381443298969072,
"support": 103.0
},
"header": {
"precision": 1.0,
"recall": 1.0,
"f1-score": 1.0,
"support": 4.0
},
"on-image": {
"precision": 0.9947643979057592,
"recall": 1.0,
"f1-score": 0.9973753280839895,
"support": 380.0
},
"on-table": {
"precision": 1.0,
"recall": 0.9443609022556391,
"f1-score": 0.97138437741686,
"support": 665.0
},
"micro avg": {
"precision": 0.9982847341337907,
"recall": 0.9267515923566879,
"f1-score": 0.9611890999174236,
"support": 1256.0
}
},
"丢弃图片准确率": {
"accuracy": 0.8666666666666667,
"precision": 0.9285714285714286,
"recall": 0.9285714285714286,
"f1_score": 0.9285714285714286
},
"丢弃表格准确率": {
"accuracy": 0,
"precision": 0,
"recall": 0,
"f1_score": 0
}
}
\ No newline at end of file
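For reference, the per-label block ("丢弃文本标签准确率") above has the shape of scikit-learn's classification_report output. A minimal sketch of producing such a block, assuming the labels are compared as flat lists; all label arrays below are hypothetical:

from sklearn.metrics import classification_report

# Hypothetical ground-truth and predicted drop-text labels.
y_true = ["rotate", "footnote", "on-image", "on-table", "header", "other"]
y_pred = ["rotate", "footnote", "on-image", "on-image", "header", "rotate"]
# Restricting `labels` to a subset of the observed labels makes the report
# include a "micro avg" entry, as in the JSON above; zero_division=0 keeps
# zero-support labels (precision/recall 0.0) from raising warnings.
labels = ["rotate", "footnote", "header", "on-image", "on-table"]

report = classification_report(y_true, y_pred, labels=labels,
                               output_dict=True, zero_division=0)
print(report["rotate"])      # per-label precision/recall/f1-score/support
print(report["micro avg"])   # aggregate over the listed labels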
@@ -432,75 +432,8 @@ def handle_multi_deletion(test_page, test_page_tag, test_page_bbox, standard_pag
def check_json_files_in_zip_exist(zip_file_path, standard_json_path_in_zip, test_json_path_in_zip):
"""
Check that the specified JSON files exist in the ZIP archive
"""
with zipfile.ZipFile(zip_file_path, 'r') as z:
# Get the list of all files in the ZIP archive
all_files_in_zip = z.namelist()
# Check that both the standard file and the test file are in the ZIP archive
if standard_json_path_in_zip not in all_files_in_zip or test_json_path_in_zip not in all_files_in_zip:
raise FileNotFoundError("One or both of the required JSON files are missing from the ZIP archive.")
def read_json_files_from_streams(standard_file_stream, test_file_stream):
"""
Read JSON content from the given file streams
"""
pdf_json_standard = [json.loads(line) for line in standard_file_stream]
pdf_json_test = [json.loads(line) for line in test_file_stream]
json_standard_origin = pd.DataFrame(pdf_json_standard)
json_test_origin = pd.DataFrame(pdf_json_test)
return json_standard_origin, json_test_origin
def read_json_files_from_zip(zip_file_path, standard_json_path_in_zip, test_json_path_in_zip):
"""
Read the two JSON files from the ZIP archive and return their DataFrames
"""
with zipfile.ZipFile(zip_file_path, 'r') as z:
with z.open(standard_json_path_in_zip) as standard_file_stream, \
z.open(test_json_path_in_zip) as test_file_stream:
standard_file_text_stream = TextIOWrapper(standard_file_stream, encoding='utf-8')
test_file_text_stream = TextIOWrapper(test_file_stream, encoding='utf-8')
json_standard_origin, json_test_origin = read_json_files_from_streams(
standard_file_text_stream, test_file_text_stream
)
return json_standard_origin, json_test_origin
def merge_json_data(json_test_df, json_standard_df):
"""
Merge the test and standard datasets on ID, returning the merged data and existence checks.
Args:
- json_test_df: DataFrame of the test data.
- json_standard_df: DataFrame of the standard data.
Returns:
- inner_merge: inner-merged DataFrame containing the matched rows.
- standard_exist: Series indicating which standard records exist.
- test_exist: Series indicating which test records exist.
"""
test_data = json_test_df[['id', 'mid_json']].drop_duplicates(subset='id', keep='first').reset_index(drop=True)
standard_data = json_standard_df[['id', 'mid_json', 'pass_label']].drop_duplicates(subset='id', keep='first').reset_index(drop=True)
outer_merge = pd.merge(test_data, standard_data, on='id', how='outer')
outer_merge.columns = ['id', 'test_mid_json', 'standard_mid_json', 'pass_label']
standard_exist = outer_merge.standard_mid_json.notnull()
test_exist = outer_merge.test_mid_json.notnull()
inner_merge = pd.merge(test_data, standard_data, on='id', how='inner')
inner_merge.columns = ['id', 'test_mid_json', 'standard_mid_json', 'pass_label']
return inner_merge, standard_exist, test_exist
def consolidate_data(test_data, standard_data, key_path):
"""
@@ -533,6 +466,20 @@ def consolidate_data(test_data, standard_data, key_path):
return overall_data_standard, overall_data_test
def overall_calculate_metrics(inner_merge, json_test, json_standard,standard_exist, test_exist):
"""
Compute the overall metrics: accuracy, precision, recall, F1 score, average edit distance, average BLEU score, segmentation accuracy, formula accuracy, formula edit distance, formula BLEU, dropped-text accuracy, dropped-text label accuracy, dropped-image accuracy, dropped-table accuracy, and so on.
Args:
inner_merge (dict): merge information, including pass_label, id, etc.
json_test (dict): JSON data of the test set.
json_standard (dict): JSON data of the standard set.
standard_exist (list): list of IDs present in the standard set.
test_exist (list): list of IDs present in the test set.
Returns:
dict: the overall metric values.
"""
process_data_standard = process_equations_and_blocks(json_standard, is_standard=True)
process_data_test = process_equations_and_blocks(json_test, is_standard=False)
@@ -739,7 +686,75 @@ def calculate_metrics(inner_merge, json_test, json_standard, json_standard_origi
return result_dict
def check_json_files_in_zip_exist(zip_file_path, standard_json_path_in_zip, test_json_path_in_zip):
"""
Check that the specified JSON files exist in the ZIP archive
"""
with zipfile.ZipFile(zip_file_path, 'r') as z:
# Get the list of all files in the ZIP archive
all_files_in_zip = z.namelist()
# Check that both the standard file and the test file are in the ZIP archive
if standard_json_path_in_zip not in all_files_in_zip or test_json_path_in_zip not in all_files_in_zip:
raise FileNotFoundError("One or both of the required JSON files are missing from the ZIP archive.")
def read_json_files_from_streams(standard_file_stream, test_file_stream):
"""
Read JSON content from the given file streams
"""
pdf_json_standard = [json.loads(line) for line in standard_file_stream]
pdf_json_test = [json.loads(line) for line in test_file_stream]
json_standard_origin = pd.DataFrame(pdf_json_standard)
json_test_origin = pd.DataFrame(pdf_json_test)
return json_standard_origin, json_test_origin
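A minimal sketch of feeding this function in-memory streams; it assumes each stream is JSON Lines, one object per line (the field values below are placeholders):

from io import StringIO

standard_stream = StringIO('{"id": "a", "mid_json": "s1"}\n{"id": "b", "mid_json": "s2"}\n')
test_stream = StringIO('{"id": "a", "mid_json": "t1"}\n')
std_df, test_df = read_json_files_from_streams(standard_stream, test_stream)
print(std_df.shape)  # (2, 2): two records with columns "id" and "mid_json"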
def read_json_files_from_zip(zip_file_path, standard_json_path_in_zip, test_json_path_in_zip):
"""
Read the two JSON files from the ZIP archive and return their DataFrames
"""
with zipfile.ZipFile(zip_file_path, 'r') as z:
with z.open(standard_json_path_in_zip) as standard_file_stream, \
z.open(test_json_path_in_zip) as test_file_stream:
standard_file_text_stream = TextIOWrapper(standard_file_stream, encoding='utf-8')
test_file_text_stream = TextIOWrapper(test_file_stream, encoding='utf-8')
json_standard_origin, json_test_origin = read_json_files_from_streams(
standard_file_text_stream, test_file_text_stream
)
return json_standard_origin, json_test_origin
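The intended call sequence, as a sketch (archive and member paths are hypothetical): validate the archive first, then load both members into DataFrames.

zip_path = "results.zip"            # hypothetical archive
standard_member = "standard.json"   # hypothetical member path inside the ZIP
test_member = "test.json"           # hypothetical member path inside the ZIP

check_json_files_in_zip_exist(zip_path, standard_member, test_member)
json_standard_origin, json_test_origin = read_json_files_from_zip(
    zip_path, standard_member, test_member
)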
def merge_json_data(json_test_df, json_standard_df):
"""
Merge the test and standard datasets on ID, returning the merged data and existence checks.
Args:
- json_test_df: DataFrame of the test data.
- json_standard_df: DataFrame of the standard data.
Returns:
- inner_merge: inner-merged DataFrame containing the matched rows.
- standard_exist: Series indicating which standard records exist.
- test_exist: Series indicating which test records exist.
"""
test_data = json_test_df[['id', 'mid_json']].drop_duplicates(subset='id', keep='first').reset_index(drop=True)
standard_data = json_standard_df[['id', 'mid_json', 'pass_label']].drop_duplicates(subset='id', keep='first').reset_index(drop=True)
outer_merge = pd.merge(test_data, standard_data, on='id', how='outer')
outer_merge.columns = ['id', 'test_mid_json', 'standard_mid_json', 'pass_label']
standard_exist = outer_merge.standard_mid_json.notnull()
test_exist = outer_merge.test_mid_json.notnull()
inner_merge = pd.merge(test_data, standard_data, on='id', how='inner')
inner_merge.columns = ['id', 'test_mid_json', 'standard_mid_json', 'pass_label']
return inner_merge, standard_exist, test_exist
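A toy illustration of the return values (data is hypothetical): the outer merge drives the two existence masks, while the inner merge keeps only IDs present on both sides.

import pandas as pd

test_df = pd.DataFrame({"id": [1, 2], "mid_json": ["t1", "t2"]})
standard_df = pd.DataFrame({"id": [2, 3], "mid_json": ["s2", "s3"],
                            "pass_label": ["yes", "no"]})

inner, standard_exist, test_exist = merge_json_data(test_df, standard_df)
print(inner["id"].tolist())      # [2]: the only id on both sides
print(standard_exist.tolist())   # [False, True, True] for ids 1, 2, 3
print(test_exist.tolist())       # [True, True, False]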
def save_results(result_dict,overall_report_dict,badcase_path,overall_path,):
"""
@@ -762,17 +777,25 @@ def save_results(result_dict,overall_report_dict,badcase_path,overall_path,):
print(f"计算结果已经保存到文件:{overall_path}")
def upload_to_s3(file_path, bucket_name, s3_file_name,AWS_ACCESS_KEY,AWS_SECRET_KEY,END_POINT_URL):
def upload_to_s3(file_path, bucket_name, s3_directory, AWS_ACCESS_KEY, AWS_SECRET_KEY, END_POINT_URL):
"""
Upload a file to Amazon S3
"""
s3 = boto3.client('s3',aws_access_key_id=AWS_ACCESS_KEY, aws_secret_access_key=AWS_SECRET_KEY,endpoint_url=END_POINT_URL)
# Create the S3 client
s3 = boto3.client('s3', aws_access_key_id=AWS_ACCESS_KEY, aws_secret_access_key=AWS_SECRET_KEY, endpoint_url=END_POINT_URL)
try:
# Extract the file name from the local file path
file_name = os.path.basename(file_path)
# Build the S3 object key by joining s3_directory and file_name
s3_object_key = f"{s3_directory}/{file_name}"  # joined directly with a slash
# Upload the file to S3
s3.upload_file(file_path, bucket_name, s3_file_name)
print(f"文件 {s3_file_name} 成功上传到S3存储桶 {bucket_name} 中的路径 {file_path}")
s3.upload_file(file_path, bucket_name, s3_object_key)
print(f"文件 {file_path} 成功上传到S3存储桶 {bucket_name} 中的目录 {s3_directory},文件名为 {file_name}")
except FileNotFoundError:
print(f"文件 {s3_file_name} 未找到,请检查文件路径是否正确。")
print(f"文件 {file_path} 未找到,请检查文件路径是否正确。")
except NoCredentialsError:
print("无法找到AWS凭证,请确认您的AWS访问密钥和密钥ID是否正确。")
except ClientError as e:
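A hypothetical invocation of the new signature; the bucket, directory, credentials, and endpoint below are placeholders, not real values:

upload_to_s3(
    "output/badcase.json",        # local file; its basename becomes the key suffix
    "my-bucket",                  # bucket_name (placeholder)
    "evaluation/reports",         # s3_directory: key becomes "evaluation/reports/badcase.json"
    "ACCESS_KEY_PLACEHOLDER",     # AWS_ACCESS_KEY (placeholder)
    "SECRET_KEY_PLACEHOLDER",     # AWS_SECRET_KEY (placeholder)
    "https://s3.example.com",     # END_POINT_URL (placeholder)
)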
@@ -801,14 +824,15 @@ def compare_edit_distance(json_file, overall_report):
json_edit_distance = json_data['pdf间的平均编辑距离']
if overall_report['pdf间的平均编辑距离'] >= json_edit_distance:
if overall_report['pdf间的平均编辑距离'] > json_edit_distance:
return 0
else:
return 1
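A sketch of the gate's semantics with a hypothetical baseline file, assuming the elided lines above load json_data from the json_file path: after the change from >= to >, a run passes (returns 1) when its average edit distance is less than or equal to the baseline's.

import json

with open("base_report.json", "w", encoding="utf-8") as f:   # hypothetical path
    json.dump({"pdf间的平均编辑距离": 20.0}, f)

overall = {"pdf间的平均编辑距离": 19.8}
print(compare_edit_distance("base_report.json", overall))    # 1: within baseline, pass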
def main(standard_file, test_file, zip_file, badcase_path, overall_path,base_data_path,s3_bucket_name=None, s3_file_name=None, AWS_ACCESS_KEY=None, AWS_SECRET_KEY=None, END_POINT_URL=None):
def main(standard_file, test_file, zip_file, badcase_path, overall_path,base_data_path, s3_bucket_name=None, s3_file_directory=None,
aws_access_key=None, aws_secret_key=None, end_point_url=None):
"""
Main function; runs the entire evaluation pipeline.
@@ -819,7 +843,7 @@ def main(standard_file, test_file, zip_file, badcase_path, overall_path,base_dat
- badcase_path: base path and filename prefix for the badcase file.
- overall_path: base path and filename prefix for the overall file.
- s3_bucket_name: S3 bucket name (optional).
- s3_file_name: file on S3 (optional).
- s3_file_directory: target directory on S3 (optional).
- AWS_ACCESS_KEY, AWS_SECRET_KEY, END_POINT_URL: AWS credentials and endpoint URL (optional).
"""
# Check that the input files exist
@@ -843,7 +867,13 @@ def main(standard_file, test_file, zip_file, badcase_path, overall_path,base_dat
save_results(result_dict, overall_report_dict,badcase_file,overall_file)
result=compare_edit_distance(base_data_path, overall_report_dict)
if all([s3_bucket_name, s3_file_directory, aws_access_key, aws_secret_key, end_point_url]):
try:
upload_to_s3(badcase_file, s3_bucket_name, s3_file_directory, aws_access_key, aws_secret_key, end_point_url)
upload_to_s3(overall_file, s3_bucket_name, s3_file_directory, aws_access_key, aws_secret_key, end_point_url)
except Exception as e:
print(f"上传到S3时发生错误: {e}")
print(result)
if __name__ == "__main__":
@@ -855,12 +885,12 @@ if __name__ == "__main__":
parser.add_argument('overall_path', type=str, help='Base path and filename prefix for the overall file.')
parser.add_argument('base_data_path', type=str, help='Base path and filename prefix for the baseline file.')
parser.add_argument('--s3_bucket_name', type=str, help='S3 bucket name.', default=None)
parser.add_argument('--s3_file_name', type=str, help='File name on S3.', default=None)
parser.add_argument('--s3_file_directory', type=str, help='Target directory on S3.', default=None)
parser.add_argument('--AWS_ACCESS_KEY', type=str, help='AWS access key.', default=None)
parser.add_argument('--AWS_SECRET_KEY', type=str, help='AWS secret key.', default=None)
parser.add_argument('--END_POINT_URL', type=str, help='AWS endpoint URL.', default=None)
args = parser.parse_args()
main(args.standard_file, args.test_file, args.zip_file, args.badcase_path,args.overall_path,args.base_data_path,args.s3_bucket_name, args.s3_file_name, args.AWS_ACCESS_KEY, args.AWS_SECRET_KEY, args.END_POINT_URL)
main(args.standard_file, args.test_file, args.zip_file, args.badcase_path,args.overall_path,args.base_data_path,args.s3_bucket_name, args.s3_file_directory, args.AWS_ACCESS_KEY, args.AWS_SECRET_KEY, args.END_POINT_URL)
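An equivalent direct call, with all values hypothetical; the five S3-related arguments may be omitted, in which case the upload step is skipped:

main(
    "standard.json", "test.json", "data.zip",
    "output/badcase", "output/overall", "base_report.json",
    s3_bucket_name="my-bucket",                 # placeholder
    s3_file_directory="evaluation/reports",     # placeholder
    aws_access_key="ACCESS_KEY_PLACEHOLDER",
    aws_secret_key="SECRET_KEY_PLACEHOLDER",
    end_point_url="https://s3.example.com",     # placeholder
)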