Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
DataFlow
Commits
97e8278b
Commit
97e8278b
authored
Dec 03, 2025
by
zzg_666
Browse files
适配后端vllm
parents
Pipeline
#3071
canceled with stages
Changes
385
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
3539 additions
and
0 deletions
+3539
-0
dataflow/operators/knowledge_cleaning/generate/file_or_url_to_markdown_converter_batch.py
...aning/generate/file_or_url_to_markdown_converter_batch.py
+283
-0
dataflow/operators/knowledge_cleaning/generate/kbc_chunk_generator.py
...rators/knowledge_cleaning/generate/kbc_chunk_generator.py
+157
-0
dataflow/operators/knowledge_cleaning/generate/kbc_chunk_generator_batch.py
.../knowledge_cleaning/generate/kbc_chunk_generator_batch.py
+176
-0
dataflow/operators/knowledge_cleaning/generate/kbc_multihop_qa_generator_batch.py
...edge_cleaning/generate/kbc_multihop_qa_generator_batch.py
+637
-0
dataflow/operators/knowledge_cleaning/generate/kbc_text_cleaner.py
...operators/knowledge_cleaning/generate/kbc_text_cleaner.py
+138
-0
dataflow/operators/knowledge_cleaning/generate/kbc_text_cleaner_batch.py
...ors/knowledge_cleaning/generate/kbc_text_cleaner_batch.py
+148
-0
dataflow/operators/knowledge_cleaning/generate/mathbook_question_extract.py
.../knowledge_cleaning/generate/mathbook_question_extract.py
+333
-0
dataflow/operators/knowledge_cleaning/generate/qa_extract.py
dataflow/operators/knowledge_cleaning/generate/qa_extract.py
+229
-0
dataflow/operators/pdf2vqa/__init__.py
dataflow/operators/pdf2vqa/__init__.py
+17
-0
dataflow/operators/pdf2vqa/generate/vqa_extractor.py
dataflow/operators/pdf2vqa/generate/vqa_extractor.py
+470
-0
dataflow/operators/reasoning/__init__.py
dataflow/operators/reasoning/__init__.py
+36
-0
dataflow/operators/reasoning/eval/reasoning_category_dataset_evaluator.py
...rs/reasoning/eval/reasoning_category_dataset_evaluator.py
+81
-0
dataflow/operators/reasoning/eval/reasoning_difficulty_dataset_evaluator.py
.../reasoning/eval/reasoning_difficulty_dataset_evaluator.py
+62
-0
dataflow/operators/reasoning/eval/reasoning_question_category_sample_evaluator.py
...ning/eval/reasoning_question_category_sample_evaluator.py
+120
-0
dataflow/operators/reasoning/eval/reasoning_question_difficulty_sample_evaluator.py
...ng/eval/reasoning_question_difficulty_sample_evaluator.py
+111
-0
dataflow/operators/reasoning/eval/reasoning_question_solvable_sample_evaluator.py
...ning/eval/reasoning_question_solvable_sample_evaluator.py
+91
-0
dataflow/operators/reasoning/eval/reasoning_token_dataset_evaluator.py
...ators/reasoning/eval/reasoning_token_dataset_evaluator.py
+96
-0
dataflow/operators/reasoning/filter/reasoning_answer_formatter_filter.py
...ors/reasoning/filter/reasoning_answer_formatter_filter.py
+83
-0
dataflow/operators/reasoning/filter/reasoning_answer_groundtruth_filter.py
...s/reasoning/filter/reasoning_answer_groundtruth_filter.py
+87
-0
dataflow/operators/reasoning/filter/reasoning_answer_model_judge_filter.py
...s/reasoning/filter/reasoning_answer_model_judge_filter.py
+184
-0
No files found.
Too many changes to show.
To preserve performance only
385 of 385+
files are displayed.
Plain diff
Email patch
dataflow/operators/knowledge_cleaning/generate/file_or_url_to_markdown_converter_batch.py
0 → 100644
View file @
97e8278b
import
pandas
as
pd
from
dataflow.utils.registry
import
OPERATOR_REGISTRY
from
dataflow
import
get_logger
from
dataflow.utils.storage
import
DataFlowStorage
from
dataflow.core
import
OperatorABC
import
os
from
pathlib
import
Path
from
trafilatura
import
fetch_url
,
extract
from
urllib.parse
import
urlparse
from
tqdm
import
tqdm
import
requests
def is_url(string):
    """Return True when *string* parses as an absolute URL (scheme + host present)."""
    try:
        parsed = urlparse(string)
    except ValueError:
        return False
    # A bare path like "example.com" has no scheme, so it is rejected.
    return bool(parsed.scheme) and bool(parsed.netloc)
def _parse_file_with_mineru(raw_file: str, output_file: str, mineru_backend: str = "vlm-vllm-engine") -> str:
    """
    Use MinerU to parse PDF/image files (pdf/png/jpg/jpeg/webp/gif) into Markdown.

    Internally, the parsed outputs for each item are stored in a structured directory:
    'intermediate_dir/pdf_name/MinerU_Version[mineru_backend]'.
    This directory stores various MinerU parsing outputs, and you can customize
    which content to extract based on your needs.

    Args:
        raw_file: Input file path, supports .pdf/.png/.jpg/.jpeg/.webp/.gif
        output_file: Despite the name, this is used as the *intermediate output
            directory* root; the actual Markdown path is derived below and returned.
        mineru_backend: Backend engine for MinerU. Options include:
            - "pipeline": Traditional pipeline processing (MinerU1)
            - "vlm-transformers" / "vlm-vllm-engine" / "vlm-http-client":
              engines based on multimodal language models (MinerU2)
            Defaults to "vlm-vllm-engine" (note: the older docs said
            "vlm-sglang-engine", which does not match this signature).
            For more details, refer to https://github.com/opendatalab/MinerU.

    Returns:
        output_file: Path to the generated Markdown file.

    Raises:
        Exception: if the `mineru` package is not importable.
        RuntimeError: if the `mineru` CLI subprocess fails.
    """
    try:
        import mineru
    except ImportError:
        raise Exception(
            """
MinerU is not installed in this environment yet.
Please refer to https://github.com/opendatalab/mineru to install.
Or you can just execute 'pip install mineru[pipeline]' and 'mineru-models-download' to fix this error.
Please make sure you have GPU on your machine.
"""
        )
    logger = get_logger()
    os.environ['MINERU_MODEL_SOURCE'] = "local"  # optional: load MinerU models from local disk
    # Maps each backend name to the sub-directory MinerU writes its output under:
    # pipeline|vlm-transformers|vlm-vllm-engine|vlm-http-client
    MinerU_Version = {
        "pipeline": "auto",
        "vlm-transformers": "vlm",
        'vlm-vllm-engine': 'vlm',
        'vlm-http-client': 'vlm'
    }
    raw_file = Path(raw_file)
    pdf_name = Path(raw_file).stem
    # Re-purpose output_file as the intermediate root and nest a "mineru" dir in it.
    intermediate_dir = output_file
    intermediate_dir = os.path.join(intermediate_dir, "mineru")
    import subprocess
    command = [
        "mineru",
        "-p", raw_file,
        "-o", intermediate_dir,
        "-b", mineru_backend,
        "--source", "local"
    ]
    try:
        result = subprocess.run(
            command,
            #stdout=subprocess.DEVNULL,
            #stderr=subprocess.DEVNULL,
            check=True
        )
    except Exception as e:
        raise RuntimeError(f"Failed to process file with MinerU: {str(e)}")
    # Directory for storing raw data, including various MinerU parsing outputs.
    # You can customize which content to extract based on your needs.
    PerItemDir = os.path.join(intermediate_dir, pdf_name, MinerU_Version[mineru_backend])
    output_file = os.path.join(PerItemDir, f"{pdf_name}.md")
    logger.info(f"Markdown saved to: {output_file}")
    return output_file
def _parse_xml_to_md(raw_file: str = None, url: str = None, output_file: str = None):
    """Extract the main content of an HTML/XML document into a Markdown file.

    Exactly one of *raw_file* / *url* should be provided (*url* takes precedence).
    On a failed URL fetch, a plain failure message is written to *output_file*
    and the path is returned so downstream steps still see a file.

    Args:
        raw_file: Path to a local HTML/XML file.
        url: URL to fetch with trafilatura.
        output_file: Destination path for the extracted Markdown.

    Returns:
        output_file: The destination path (written on success or fetch failure;
        on an extraction error the error is logged and the path is still returned).

    Raises:
        Exception: if neither *raw_file* nor *url* is given.
    """
    logger = get_logger()
    if url:
        downloaded = fetch_url(url)
        if not downloaded:
            downloaded = "fail to fetch this url. Please check your Internet Connection or URL correctness"
            with open(output_file, "w", encoding="utf-8") as f:
                f.write(downloaded)
            return output_file
    elif raw_file:
        with open(raw_file, "r", encoding='utf-8') as f:
            downloaded = f.read()
    else:
        raise Exception("Please provide at least one of file path and url string.")
    try:
        result = extract(downloaded, output_format="markdown", with_metadata=True)
        # BUG FIX: trafilatura.extract returns None when nothing could be
        # extracted; f.write(None) used to raise TypeError which the except
        # below silently logged, leaving no output file at all.
        if result is None:
            result = ""
        with open(output_file, "w", encoding="utf-8") as f:
            f.write(result)
        # Log success only after the write actually happened.
        logger.info(f"Extracted content is written into {output_file}")
    except Exception as e:
        # BUG FIX: the original passed `e` as a stray positional logging arg
        # that was never interpolated into the message.
        logger.error(f"Error during extract this file or link: {e}")
    return output_file
def is_pdf_url(url):
    """Return True if *url* answers HTTP 200 with a PDF Content-Type.

    Sends a HEAD request (headers only, no body download) following redirects.

    Args:
        url: The URL to probe.

    Returns:
        bool: True for a 200 response whose media type is application/pdf,
        False otherwise (including any request failure).
    """
    try:
        # HEAD request: fetch only the response headers, not the file body.
        # BUG FIX: added a timeout so a stalled server cannot hang the pipeline.
        response = requests.head(url, allow_redirects=True, timeout=30)
        # BUG FIX: Content-Type may carry parameters ("application/pdf; name=x"),
        # so compare only the media type, case-insensitively, instead of the
        # whole header value.
        media_type = (response.headers.get('Content-Type') or "").split(";")[0].strip().lower()
        if response.status_code == 200 and media_type == 'application/pdf':
            return True
        print(f"Content-Type: {response.headers.get('Content-Type')}")
        return False
    except requests.exceptions.RequestException:
        # Network failure, DNS error, timeout, etc.
        print("Request failed")
        return False
def download_pdf(url, save_path):
    """Download a PDF from *url* to *save_path*, creating parent directories.

    The file is streamed in 1 KiB chunks; nothing is written unless the server
    returns HTTP 200 with a PDF Content-Type. Errors are printed, not raised.

    Args:
        url: Source URL of the PDF.
        save_path: Destination file path (parent dirs are created as needed).
    """
    try:
        # BUG FIX: added a timeout, and close the streamed response
        # deterministically via the context manager (stream=True keeps the
        # connection open until consumed or closed).
        with requests.get(url, stream=True, timeout=30) as response:
            # BUG FIX: compare only the media type; the header may carry
            # parameters such as "; charset=...".
            media_type = (response.headers.get('Content-Type') or "").split(";")[0].strip().lower()
            if response.status_code == 200 and media_type == 'application/pdf':
                pdf_folder = os.path.dirname(save_path)
                os.makedirs(pdf_folder, exist_ok=True)
                with open(save_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=1024):
                        if chunk:  # skip keep-alive chunks
                            f.write(chunk)
                print(f"PDF saved to {save_path}")
            else:
                print("The URL did not return a valid PDF file.")
    except requests.exceptions.RequestException as e:
        print(f"Error downloading PDF: {e}")
@OPERATOR_REGISTRY.register()
class FileOrURLToMarkdownConverterBatch(OperatorABC):
    """
    Batch operator that converts local files or URLs into Markdown text files.

    mineru_backend sets the backend engine for MinerU. Options include:
      - "pipeline": Traditional pipeline processing (MinerU1)
      - "vlm-sglang-engine": New engine based on multimodal language models (MinerU2) (default recommended)
    Choose the appropriate backend based on your needs. Defaults to "vlm-sglang-engine".
    For more details, refer to the MinerU GitHub: https://github.com/opendatalab/MinerU.
    """

    def __init__(self, intermediate_dir: str = "intermediate", lang: str = "en", mineru_backend: str = "vlm-sglang-engine"):
        """Create the converter.

        Args:
            intermediate_dir: Directory for intermediate conversion outputs
                (created here if missing).
            lang: Document language hint ("en"/"zh"); stored but not read in run().
            mineru_backend: MinerU engine name passed through to
                _parse_file_with_mineru for PDF/image inputs.
        """
        self.logger = get_logger()
        self.intermediate_dir = intermediate_dir
        os.makedirs(self.intermediate_dir, exist_ok=True)
        self.lang = lang
        self.mineru_backend = mineru_backend

    @staticmethod
    def get_desc(lang: str = "zh"):
        """
        Return a human-readable description of this operator (matching run()'s behavior),
        in Chinese for lang == "zh", otherwise in English.
        """
        if lang == "zh":
            return (
                "知识提取算子:支持从多种文件格式中提取结构化内容并转换为标准Markdown\n"
                "核心功能:\n"
                "1. PDF文件:使用MinerU解析引擎提取文本/表格/公式,保留原始布局\n"
                "2. Office文档(DOC/PPT等):通过DocConverter转换为Markdown格式\n"
                "3. 网页内容(HTML/XML):使用trafilatura提取正文并转为Markdown\n"
                "4. 纯文本(TXT/MD):直接透传不做处理\n"
                "特殊处理:\n"
                "- 自动识别中英文文档(lang参数)\n"
                "- 支持本地文件路径和URL输入\n"
                "- 生成中间文件到指定目录(intermediate_dir)"
            )
        else:
            # Default: English description
            return (
                "Knowledge Extractor: Converts multiple file formats to structured Markdown\n"
                "Key Features:\n"
                "1. PDF: Uses MinerU engine to extract text/tables/formulas with layout preservation\n"
                "2. Office(DOC/PPT): Converts to Markdown via DocConverter\n"
                "3. Web(HTML/XML): Extracts main content using trafilatura\n"
                "4. Plaintext(TXT/MD): Directly passes through without conversion\n"
                "Special Handling:\n"
                "- Auto-detects Chinese/English documents(lang param)\n"
                "- Supports both local files and URLs\n"
                "- Generates intermediate files to specified directory(intermediate_dir)"
            )

    def run(self, storage: DataFlowStorage, input_key: str = "source", output_key: str = "text_path"):
        """Convert every row's source (path or URL) to a Markdown file path.

        Reads the "dataframe" from *storage*, and for each row's *input_key* value:
          - URL to a PDF: downloads it next to the storage file, then parses it
            like a local PDF below.
          - Other URL: extracts it to Markdown via _parse_xml_to_md and records
            the result immediately.
          - Local .pdf/.png/.jpg/.jpeg/.webp/.gif: parsed with MinerU.
          - Local .html/.xml: extracted with trafilatura.
          - Local .txt/.md: passed through unchanged (the path itself is kept).
          - Missing paths / unsupported types: logged and recorded as "".
        Writes the per-row output paths into *output_key* and persists the
        dataframe back to storage.

        Returns:
            The path the updated dataframe was written to by storage.write().
        """
        self.logger.info("Starting content extraction...")
        self.logger.info("If the input is a URL or a large file, this process might take some time. Please wait...")
        dataframe = storage.read("dataframe")
        self.logger.info(f"Loaded dataframe with {len(dataframe)} entries.")
        output_file_all = []
        # Wrap iterrows with tqdm for progress tracking
        for index, row in tqdm(dataframe.iterrows(), total=len(dataframe), desc="FileOrURLToMarkdownConverter Processing files", ncols=100):
            content = row.get(input_key, "")
            if is_url(content):
                # Case: Input is a URL
                if is_pdf_url(content):
                    # Download the PDF beside the storage file, then fall through
                    # to the local-file handling below (content becomes the path).
                    pdf_save_path = os.path.join(os.path.dirname(storage.first_entry_file_name), f"raw/crawled/crawled_{index}.pdf")
                    self.logger.info(f"Downloading PDF from {content} to {pdf_save_path}")
                    download_pdf(content, pdf_save_path)
                    content = pdf_save_path
                    self.logger.info(f"pdf file has been fetched and saved to {pdf_save_path}")
                else:
                    # Non-PDF URL: extract directly to Markdown and move on.
                    output_file = os.path.join(os.path.dirname(storage.first_entry_file_name), f"raw/crawled/crawled_{index}.md")
                    os.makedirs(os.path.dirname(output_file), exist_ok=True)
                    output_file = _parse_xml_to_md(url=content, output_file=output_file)
                    self.logger.info(f"Primary extracted result written to: {output_file}")
                    output_file_all.append(output_file)
                    continue
            # Extract file name and extension
            raw_file = content
            raw_file_name = os.path.splitext(os.path.basename(raw_file))[0]
            raw_file_suffix = os.path.splitext(raw_file)[1].lower()
            raw_file_suffix_no_dot = raw_file_suffix.lstrip(".")
            # Define default output path (used by the .html/.xml branch below;
            # the MinerU branch derives its own path).
            output_file = os.path.join(self.intermediate_dir, f"{raw_file_name}_{raw_file_suffix_no_dot}.md")
            # Case: Local file path
            if not os.path.exists(content):
                self.logger.error(f"File not found: Path {content} does not exist.")
                output_file_all.append("")
                continue
            _, ext = os.path.splitext(content)
            ext = ext.lower()
            if ext in [".pdf", ".png", ".jpg", ".jpeg", ".webp", ".gif"]:
                self.logger.info(f"Using MinerU backend: {self.mineru_backend}")
                # NOTE: output_file here is the intermediate *directory* root;
                # _parse_file_with_mineru returns the real Markdown path.
                output_file = _parse_file_with_mineru(raw_file=content, output_file=self.intermediate_dir, mineru_backend=self.mineru_backend)
            elif ext in [".html", ".xml"]:
                output_file = _parse_xml_to_md(raw_file=content, output_file=output_file)
            elif ext in [".txt", ".md"]:
                output_file = content  # No parsing needed for plain text or Markdown files
            else:
                self.logger.error(f"Unsupported file type: {ext} for file {content}")
                output_file = ""
            output_file_all.append(output_file)
        # Save results back to storage
        dataframe[output_key] = output_file_all
        output_file_path = storage.write(dataframe)
        self.logger.info(f"Final extraction results saved to: {output_file_path}")
        return output_file_path
dataflow/operators/knowledge_cleaning/generate/kbc_chunk_generator.py
0 → 100644
View file @
97e8278b
import
os
import
json
from
typing
import
Dict
,
List
,
Optional
from
chonkie
import
(
TokenChunker
,
SentenceChunker
,
SemanticChunker
,
RecursiveChunker
)
from
tokenizers
import
Tokenizer
from
transformers
import
AutoTokenizer
import
pandas
as
pd
from
dataflow.utils.registry
import
OPERATOR_REGISTRY
from
dataflow
import
get_logger
from
dataflow.utils.storage
import
DataFlowStorage
from
dataflow.core
import
OperatorABC
@OPERATOR_REGISTRY.register()
class KBCChunkGenerator(OperatorABC):
    """Split source documents (txt/md/xml/json/jsonl) into chunks, one row per chunk.

    A chunking backend from `chonkie` is selected by *split_method* and each
    input row is expanded into one output row per produced chunk, with the
    chunk text stored under *output_key*.
    """

    def __init__(
        self,
        chunk_size: int = 512,
        chunk_overlap: int = 50,
        split_method: str = "token",
        min_tokens_per_chunk: int = 100,
        tokenizer_name: str = "bert-base-uncased",
    ):
        """Configure the splitter.

        Args:
            chunk_size: Target chunk size passed to the chunker.
            chunk_overlap: Overlap between consecutive chunks (token/sentence/recursive).
            split_method: One of "token", "sentence", "semantic", "recursive".
            min_tokens_per_chunk: Stored for interface compatibility; not read in run().
            tokenizer_name: HuggingFace tokenizer used for token counting/chunking.
        """
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.split_method = split_method
        self.min_tokens_per_chunk = min_tokens_per_chunk
        # BUG FIX: removed the dead no-op `tokenizer_name = tokenizer_name`.
        # Initialize tokenizer and chunker.
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.chunker = self._initialize_chunker()
        self.logger = get_logger()

    @staticmethod
    def get_desc(lang: str = "zh"):
        """Return the operator description (tuple of strings) in the requested language."""
        if lang == "zh":
            return (
                "CorpusTextSplitter是轻量级文本分割工具,",
                "支持词/句/语义/递归分块,",
                "可配置块大小、重叠和最小块长度",
            )
        elif lang == "en":
            return (
                "CorpusTextSplitter is a lightweight text segmentation tool",
                "that supports multiple chunking methods",
                "(token/sentence/semantic/recursive) with configurable size and overlap,",
                "optimized for RAG applications."
            )

    def _initialize_chunker(self):
        """Initialize the appropriate chunker based on method"""
        if self.split_method == "token":
            return TokenChunker(
                tokenizer=self.tokenizer,
                chunk_size=self.chunk_size,
                chunk_overlap=self.chunk_overlap
            )
        elif self.split_method == "sentence":
            return SentenceChunker(
                chunk_size=self.chunk_size,
                chunk_overlap=self.chunk_overlap
            )
        elif self.split_method == "semantic":
            # SemanticChunker takes no overlap parameter.
            return SemanticChunker(
                chunk_size=self.chunk_size,
            )
        elif self.split_method == "recursive":
            return RecursiveChunker(
                chunk_size=self.chunk_size,
                chunk_overlap=self.chunk_overlap
            )
        else:
            raise ValueError(f"Unsupported split method: {self.split_method}")

    def _load_text(self, file_path) -> str:
        """Load text from input file.

        Supports .txt/.md/.xml (read as-is) and .json/.jsonl (first of the
        'text'/'content'/'body' fields found, joined with newlines for lists).

        Raises:
            FileNotFoundError: if *file_path* does not exist.
            ValueError: for unsupported formats or JSON without a text field.
        """
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Input file not found: {file_path}")
        if file_path.endswith('.txt') or file_path.endswith('.md') or file_path.endswith('.xml'):
            with open(file_path, 'r', encoding='utf-8') as f:
                return f.read()
        elif file_path.endswith(('.json', '.jsonl')):
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f) if file_path.endswith('.json') else [json.loads(line) for line in f]
            text_fields = ['text', 'content', 'body']
            for field in text_fields:
                if isinstance(data, list) and len(data) > 0 and field in data[0]:
                    return "\n".join([item[field] for item in data])
                elif isinstance(data, dict) and field in data:
                    return data[field]
            raise ValueError("No text field found in JSON input")
        else:
            raise ValueError("Unsupported file format")

    def _validate_dataframe(self, dataframe: pd.DataFrame):
        """Raise if the configured output column would overwrite an existing one."""
        forbidden_keys = [self.output_key]
        conflict = [k for k in forbidden_keys if k in dataframe.columns]
        if conflict:
            raise ValueError(f"The following column(s) already exist and would be overwritten: {conflict}")

    def run(self, storage: DataFlowStorage, input_key: str = 'text_path', output_key: str = "raw_chunk"):
        """Perform text splitting and save results.

        Each input row is expanded into one output row per chunk; all original
        columns are preserved and the chunk text is added under *output_key*.

        Returns:
            [output_key]: the list of columns this operator added.
        """
        self.input_key = input_key
        self.output_key = output_key
        dataframe = storage.read("dataframe")
        self._validate_dataframe(dataframe)
        text_paths = dataframe[self.input_key].tolist()
        for input_path in text_paths:
            if not input_path or not os.path.exists(input_path):
                self.logger.error(f"无效的输入文件路径: {input_path}")
        new_records = []
        for row_dict, text_path in zip(dataframe.to_dict(orient='records'), text_paths):
            # BUG FIX: invalid paths were logged above but _load_text was still
            # called on them, crashing with FileNotFoundError; skip them instead.
            if not text_path or not os.path.exists(text_path):
                continue
            text = self._load_text(text_path)
            if text:
                # Count tokens to decide whether the text fits the tokenizer limit.
                tokens = self.tokenizer.encode(text)
                total_tokens = len(tokens)
                max_tokens = self.tokenizer.model_max_length  # tokenizer's max token limit
                # BUG FIX: leftover debug print replaced with a debug-level log.
                self.logger.debug(f"max_tokens: {self.tokenizer.model_max_length}")
                if total_tokens <= max_tokens:
                    chunks = self.chunker(text)
                else:
                    # Number of pieces needed (ceiling division).
                    x = (total_tokens + max_tokens - 1) // max_tokens
                    # Approximate equal split by whitespace-separated words.
                    words = text.split()
                    words_per_chunk = (len(words) + x - 1) // x  # words per piece
                    chunks = []
                    for i in range(0, len(words), words_per_chunk):
                        chunk_text = ' '.join(words[i:i + words_per_chunk])
                        chunks.extend(self.chunker(chunk_text))
                # One output record per chunk, preserving all original columns.
                for chunk in chunks:
                    new_row = row_dict.copy()
                    new_row[self.output_key] = chunk.text  # add/overwrite output_key field
                    new_records.append(new_row)
        new_df = pd.DataFrame(new_records)
        output_file = storage.write(new_df)
        self.logger.info(f"Successfully split text for {len(text_paths)} files. Saved to {output_file}")
        return [output_key]
\ No newline at end of file
dataflow/operators/knowledge_cleaning/generate/kbc_chunk_generator_batch.py
0 → 100644
View file @
97e8278b
import
os
import
json
from
typing
import
Dict
,
List
,
Optional
from
chonkie
import
(
TokenChunker
,
SentenceChunker
,
SemanticChunker
,
RecursiveChunker
)
from
tokenizers
import
Tokenizer
from
transformers
import
AutoTokenizer
import
pandas
as
pd
from
dataflow.utils.registry
import
OPERATOR_REGISTRY
from
dataflow
import
get_logger
from
dataflow.utils.storage
import
DataFlowStorage
from
dataflow.core
import
OperatorABC
@OPERATOR_REGISTRY.register()
class KBCChunkGeneratorBatch(OperatorABC):
    """Batch variant of the chunk generator: writes one `<name>_chunk.json` file
    per input document and records the chunk-file path per row instead of
    expanding rows per chunk.
    """

    def __init__(
        self,
        chunk_size: int = 512,
        chunk_overlap: int = 50,
        split_method: str = "token",
        min_tokens_per_chunk: int = 100,
        tokenizer_name: str = "bert-base-uncased",
    ):
        """Configure the splitter.

        Args:
            chunk_size: Target chunk size passed to the chunker.
            chunk_overlap: Overlap between consecutive chunks (token/sentence/recursive).
            split_method: One of "token", "sentence", "semantic", "recursive".
            min_tokens_per_chunk: Stored for interface compatibility; not read in run().
            tokenizer_name: HuggingFace tokenizer used for token counting/chunking.
        """
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.split_method = split_method
        self.min_tokens_per_chunk = min_tokens_per_chunk
        # BUG FIX: removed the dead no-op `tokenizer_name = tokenizer_name`.
        # Initialize tokenizer and chunker.
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.chunker = self._initialize_chunker()
        self.logger = get_logger()

    @staticmethod
    def get_desc(lang: str = "zh"):
        """Return the operator description (tuple of strings) in the requested language."""
        if lang == "zh":
            return (
                "CorpusTextSplitter是轻量级文本分割工具,",
                "支持词/句/语义/递归分块,",
                "可配置块大小、重叠和最小块长度",
            )
        elif lang == "en":
            return (
                "CorpusTextSplitter is a lightweight text segmentation tool",
                "that supports multiple chunking methods",
                "(token/sentence/semantic/recursive) with configurable size and overlap,",
                "optimized for RAG applications."
            )

    def _initialize_chunker(self):
        """Initialize the appropriate chunker based on method"""
        if self.split_method == "token":
            return TokenChunker(
                tokenizer=self.tokenizer,
                chunk_size=self.chunk_size,
                chunk_overlap=self.chunk_overlap
            )
        elif self.split_method == "sentence":
            return SentenceChunker(
                chunk_size=self.chunk_size,
                chunk_overlap=self.chunk_overlap
            )
        elif self.split_method == "semantic":
            # SemanticChunker takes no overlap parameter.
            return SemanticChunker(
                chunk_size=self.chunk_size,
            )
        elif self.split_method == "recursive":
            return RecursiveChunker(
                chunk_size=self.chunk_size,
                chunk_overlap=self.chunk_overlap
            )
        else:
            raise ValueError(f"Unsupported split method: {self.split_method}")

    def _load_text(self, text_paths: List[str]) -> List[str]:
        """Load text from file list.

        Returns one string per input path, "" for missing files, so the result
        always aligns 1:1 with *text_paths*.

        Raises:
            ValueError: for unsupported formats or JSON without a text field.
        """
        texts = []
        for text_path in text_paths:
            if not os.path.exists(text_path):
                self.logger.error(f"Input file not found: {text_path}")
                texts.append("")
            elif text_path.endswith('.txt') or text_path.endswith('.md') or text_path.endswith('.xml'):
                with open(text_path, 'r', encoding='utf-8') as f:
                    texts.append(f.read())
            elif text_path.endswith(('.json', '.jsonl')):
                with open(text_path, 'r', encoding='utf-8') as f:
                    data = json.load(f) if text_path.endswith('.json') else [json.loads(line) for line in f]
                text_fields = ['text', 'content', 'body']
                # BUG FIX: the original loop never broke after a match (so several
                # fields could be appended for one file), its trailing
                # `if field not in text_fields: raise` was dead code (the loop
                # variable is always a member), and when no field matched nothing
                # was appended at all, misaligning `texts` with `text_paths`.
                for field in text_fields:
                    if isinstance(data, list) and len(data) > 0 and field in data[0]:
                        texts.append("\n".join([item[field] for item in data]))
                        break
                    elif isinstance(data, dict) and field in data:
                        texts.append(data[field])
                        break
                else:
                    raise ValueError("No text field found in JSON input")
            else:
                raise ValueError(f"Unsupported file format for {text_path}")
        return texts

    def _validate_dataframe(self, dataframe: pd.DataFrame):
        """Raise if the configured output column would overwrite an existing one."""
        forbidden_keys = [self.output_key]
        conflict = [k for k in forbidden_keys if k in dataframe.columns]
        if conflict:
            raise ValueError(f"The following column(s) already exist and would be overwritten: {conflict}")

    def run(self, storage: DataFlowStorage, input_key: str = "text_path", output_key: str = "chunk_path"):
        """Perform text splitting and save results.

        For each row's source file, writes `<stem>_chunk.json` (a list of
        {"raw_chunk": ...} objects) into an "extract" directory beside the
        source, and records that path (or "" for unreadable inputs) under
        *output_key*.

        Returns:
            [output_key]: the list of columns this operator added.
        """
        self.input_key = input_key
        self.output_key = output_key
        dataframe = storage.read("dataframe")
        self._validate_dataframe(dataframe)
        text_paths = dataframe[self.input_key].tolist()
        texts = self._load_text(text_paths)
        output_paths = []
        chunks = []
        for i, text in enumerate(texts):
            if text:
                # Count tokens to decide whether the text fits the tokenizer limit.
                tokens = self.tokenizer.encode(text)
                total_tokens = len(tokens)
                max_tokens = self.tokenizer.model_max_length  # tokenizer's max token limit
                # BUG FIX: leftover debug print replaced with a debug-level log.
                self.logger.debug(f"max_tokens: {self.tokenizer.model_max_length}")
                if total_tokens <= max_tokens:
                    chunks = self.chunker(text)
                else:
                    # Number of pieces needed (ceiling division).
                    x = (total_tokens + max_tokens - 1) // max_tokens
                    # Approximate equal split by whitespace-separated words.
                    words = text.split()
                    words_per_chunk = (len(words) + x - 1) // x  # words per piece
                    chunks = []
                    for j in range(0, len(words), words_per_chunk):
                        chunk_text = ' '.join(words[j:j + words_per_chunk])
                        chunks.extend(self.chunker(chunk_text))
                json_chunks = [{"raw_chunk": chunk.text,} for chunk in chunks]
                output_dir = "/".join([os.path.dirname(text_paths[i]), "extract"])
                os.makedirs(output_dir, exist_ok=True)
                file_name = os.path.splitext(os.path.basename(text_paths[i]))[0] + '_chunk.json'
                output_path = os.path.join(output_dir, file_name)
                output_paths.append(output_path)
                with open(output_path, "w", encoding="utf-8") as f:
                    json.dump(json_chunks, f, ensure_ascii=False, indent=4)
                self.logger.info(f"Successfully split {text_paths[i]} into {len(chunks)} chunks. Saved to {output_path}")
            else:
                output_paths.append("")
        # BUG FIX: removed the ">>>/<<<" debug prints in favor of a debug log.
        self.logger.debug(f"chunk output paths: {output_paths}")
        dataframe[self.output_key] = output_paths
        output_file = storage.write(dataframe)
        # BUG FIX: the old message reported len(chunks) of only the *last* file.
        self.logger.info(f"Successfully split text for {len(text_paths)} files. Saved to {output_file}")
        return [output_key]
dataflow/operators/knowledge_cleaning/generate/kbc_multihop_qa_generator_batch.py
0 → 100644
View file @
97e8278b
from
dataflow.prompts.text2qa
import
Text2MultiHopQAGeneratorPrompt
import
pandas
as
pd
from
dataflow.utils.registry
import
OPERATOR_REGISTRY
from
dataflow
import
get_logger
from
dataflow.utils.storage
import
DataFlowStorage
from
dataflow.core
import
OperatorABC
from
dataflow.core
import
LLMServingABC
import
random
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Sequence
import
json
from
tqdm
import
tqdm
import
re
from
dataflow.core.prompt
import
prompt_restrict
,
DIYPromptABC
from
typing
import
Union
@
prompt_restrict
(
Text2MultiHopQAGeneratorPrompt
)
@
OPERATOR_REGISTRY
.
register
()
class
KBCMultiHopQAGeneratorBatch
(
OperatorABC
):
r
"""A processor for generating multi-hop question-answer pairs from user
data.
This class handles the processing of text data to generate multi-hop
question-answer pairs using either an AI model or rule-based approaches.
It manages the entire pipeline from text preprocessing to dataset curation.
"""
def
__init__
(
self
,
llm_serving
:
LLMServingABC
,
seed
:
int
=
0
,
lang
=
"en"
,
prompt_template
:
Union
[
Text2MultiHopQAGeneratorPrompt
,
DIYPromptABC
]
=
None
):
r
"""Initialize the UserDataProcessor.
Args:
config (Optional[ProcessorConfig], optional): Configuration for
data processing. (default: :obj:`None`)
"""
self
.
rng
=
random
.
Random
(
seed
)
self
.
llm_serving
=
llm_serving
self
.
lang
=
lang
self
.
logger
=
get_logger
()
if
prompt_template
:
self
.
prompt_template
=
prompt_template
else
:
self
.
prompt_template
=
Text2MultiHopQAGeneratorPrompt
()
    @staticmethod
    def get_desc(lang: str = "zh") -> tuple:
        """Returns a description of the processor's functionality.
        Args:
            lang (str, optional): Language for description ('zh' or 'en').
        Returns:
            tuple: Description strings in specified language, including format example
        """
        # Human-readable registry description; each branch embeds an example of
        # the JSON structure produced by the operator.
        if lang == "zh":
            return (
                "MultiHopQAGenerator 是多跳问答对生成处理器,支持从文本中自动生成需要多步推理的问题与答案。",
                "处理流程包括:文本预处理、信息抽取、问题生成与回答生成,支持自定义语言模型后端和参数。",
                "输出格式如下:",
                "输入:\n"
                "text: <原始上下文文本>",
                "输出:\n"
                "{\n"
                "  \"text\": <处理后的文本字符串>,\n"
                "  \"qa_pairs\": [\n"
                "    {\n"
                "      \"question\": <字符串:生成的问题>,\n"
                "      \"reasoning_steps\": [\n"
                "         {\"step\": <推理过程的步骤 1>},\n"
                "         {\"step\": <步骤 2>} ...\n"
                "      ],\n"
                "      \"answer\": <字符串:最终答案>,\n"
                "      \"supporting_facts\": [<支持该答案的事实 1>, <事实 2>, ...],\n"
                "      \"type\": <可选:问题类型,如“生物学”、“历史”等>\n"
                "    },\n"
                "    ...\n"
                "  ],\n"
                "  \"metadata\": {\n"
                "    \"source\": <数据来源>,\n"
                "    \"timestamp\": <时间戳字符串>,\n"
                "    \"complexity\": <整数:问题复杂度标记>\n"
                "  }\n"
                "}"
            )
        else:
            return (
                "MultiHopQAGenerator is a processor for generating multi-hop question-answer pairs from raw text.",
                "It includes preprocessing, information extraction, and reasoning-based QA generation, with configurable LLM backends.",
                "Expected output format:",
                "Input:\n"
                "text: <raw input context>",
                "Output:\n"
                "{\n"
                "  \"text\": <processed input text>,\n"
                "  \"qa_pairs\": [\n"
                "    {\n"
                "      \"question\": <string: generated question>,\n"
                "      \"reasoning_steps\": [\n"
                "         {\"step\": <inference step 1>},\n"
                "         {\"step\": <inference step 2>} ...\n"
                "      ],\n"
                "      \"answer\": <string: final answer>,\n"
                "      \"supporting_facts\": [<fact 1>, <fact 2>, ...],\n"
                "      \"type\": <optional string: QA category>\n"
                "    },\n"
                "    ...\n"
                "  ],\n"
                "  \"metadata\": {\n"
                "    \"source\": <source string>,\n"
                "    \"timestamp\": <timestamp string>,\n"
                "    \"complexity\": <integer: reasoning complexity>\n"
                "  }\n"
                "}"
            )
def
process_text
(
self
,
text
:
str
,
source
:
str
=
"user_input"
)
->
List
[
Dict
[
str
,
Any
]]:
r
"""Process a single text to generate multi-hop QA pairs.
Args:
text (str): The input text to process.
source (str, optional): Source identifier for the text.
(default: :obj:`"user_input"`)
Returns:
List[Dict[str, Any]]: List of processed examples with QA pairs and
metadata.
"""
# Convert text to standard format
raw_data
=
[
{
'text'
:
text
,
'source'
:
source
,
}
]
# Construct examples
constructor
=
ExampleConstructor
(
lang
=
self
.
lang
,
llm_serving
=
self
.
llm_serving
,
prompt_template
=
self
.
prompt_template
)
examples
=
constructor
.
construct_examples
(
raw_data
)
# Manage data
# curator = DataCurator(self.config, self.rng)
# final_dataset = curator.curate_dataset(examples)
return
examples
def
process_batch
(
self
,
texts
:
List
[
str
],
sources
:
Optional
[
List
[
str
]]
=
None
)
->
List
[
Dict
[
str
,
Any
]]:
r
"""Process multiple texts in batch to generate multi-hop QA pairs.
Args:
texts (List[str]): List of input texts to process.
sources (Optional[List[str]], optional): List of source
identifiers. (default: :obj:`None`)
Returns:
List[Dict[str, Any]]: List of processed examples with QA pairs and
metadata.
Raises:
ValueError: If length of sources doesn't match length of texts.
"""
if
sources
is
None
:
sources
=
[
"default_source"
]
*
len
(
texts
)
elif
len
(
sources
)
!=
len
(
texts
):
raise
ValueError
(
"Length of sources must match length of texts"
)
raw_data
=
[
{
'text'
:
text
,
'source'
:
source
,
}
for
text
,
source
in
zip
(
texts
,
sources
)
]
# Construct examples
constructor
=
ExampleConstructor
(
lang
=
self
.
lang
,
llm_serving
=
self
.
llm_serving
,
)
examples
=
constructor
.
construct_examples
(
raw_data
)
# # Manage data
# curator = DataCurator(self.config, self.rng)
# final_dataset = curator.curate_dataset(examples)
return
examples
def
_validate_dataframe
(
self
,
dataframe
:
pd
.
DataFrame
):
required_keys
=
[
self
.
input_key
]
forbidden_keys
=
[
self
.
output_key
]
missing
=
[
k
for
k
in
required_keys
if
k
not
in
dataframe
.
columns
]
conflict
=
[
k
for
k
in
forbidden_keys
if
k
in
dataframe
.
columns
]
if
missing
:
raise
ValueError
(
f
"Missing required column(s):
{
missing
}
"
)
if
conflict
:
raise
ValueError
(
f
"The following column(s) already exist and would be overwritten:
{
conflict
}
"
)
def
run
(
self
,
storage
:
DataFlowStorage
=
None
,
input_key
:
str
=
'chunk_path'
,
output_key
:
str
=
'enhanced_chunk_path'
,
):
self
.
input_key
,
self
.
output_key
=
input_key
,
output_key
dataframe
=
storage
.
read
(
"dataframe"
)
self
.
_validate_dataframe
(
dataframe
)
chunk_paths
=
dataframe
[
self
.
input_key
].
tolist
()
for
chunk_path
in
chunk_paths
:
if
(
chunk_path
):
texts
=
[]
if
str
(
chunk_path
).
endswith
(
".json"
):
with
open
(
chunk_path
,
"r"
,
encoding
=
"utf-8"
)
as
f
:
data
=
json
.
load
(
f
)
texts
=
[
item
[
"cleaned_chunk"
]
for
item
in
data
]
elif
str
(
chunk_path
).
endswith
(
".jsonl"
):
with
open
(
chunk_path
,
"r"
,
encoding
=
"utf-8"
)
as
f
:
data
=
[
json
.
loads
(
line
)
for
line
in
f
]
texts
=
[
item
[
"cleaned_chunk"
]
for
item
in
data
]
else
:
print
(
f
"Unsupported file format:
{
chunk_path
}
"
)
continue
# 生成 QA 对
qa_pairs_batch
=
self
.
process_batch
(
texts
)
# 写入到原数据中
for
item
,
qa_pairs
in
zip
(
data
,
qa_pairs_batch
):
item
[
"qa_pairs"
]
=
qa_pairs
# 回写到原始文件中(覆盖写入)
with
open
(
chunk_path
,
"w"
,
encoding
=
"utf-8"
)
as
f
:
if
str
(
chunk_path
).
endswith
(
".json"
):
json
.
dump
(
data
,
f
,
ensure_ascii
=
False
,
indent
=
4
)
else
:
# jsonl
for
item
in
data
:
f
.
write
(
json
.
dumps
(
item
,
ensure_ascii
=
False
)
+
"
\n
"
)
self
.
logger
.
info
(
f
"constructed
{
len
(
qa_pairs
)
}
multihop QA for
{
chunk_path
}
"
)
dataframe
[
self
.
output_key
]
=
chunk_paths
output_file
=
storage
.
write
(
dataframe
)
self
.
logger
.
info
(
f
"Results saved to
{
output_file
}
"
)
return
[
output_key
]
class ExampleConstructor:
    r"""Constructs training examples from raw text data.
    This class handles the construction of training examples by preprocessing
    text, extracting information pairs, and generating question-answer pairs.
    """

    def __init__(self,
                 lang: str = "en",
                 llm_serving: LLMServingABC = None,
                 min_text_length: int = 100,
                 max_text_length: int = 200000,
                 prompt_template=None
                 ):
        r"""Initialize the ExampleConstructor.

        Args:
            lang (str, optional): Corpus language, ``"en"`` or ``"zh"``.
                (default: :obj:`"en"`)
            llm_serving (LLMServingABC, optional): Backend used to generate the
                QA pairs. (default: :obj:`None`)
            min_text_length (int, optional): Texts shorter than this are
                rejected during preprocessing. (default: :obj:`100`)
            max_text_length (int, optional): Texts longer than this are
                rejected during preprocessing. (default: :obj:`200000`)
            prompt_template (optional): Prompt template; defaults to a fresh
                ``Text2MultiHopQAGeneratorPrompt`` when omitted.
        """
        # Doc fix: the previous docstring documented a nonexistent ``config``
        # parameter inherited from an older class.
        self.lang = lang
        # NOTE: the historical attribute name "llm_sering" (sic) is kept
        # because other methods read it; the correctly spelled alias is added
        # for new call sites.
        self.llm_sering = llm_serving
        self.llm_serving = llm_serving
        self.logger = get_logger()
        self.max_length = max_text_length
        self.min_length = min_text_length
        if prompt_template:
            self.prompt_template = prompt_template
        else:
            self.prompt_template = Text2MultiHopQAGeneratorPrompt()
def
construct_examples
(
self
,
raw_data
:
List
[
Dict
[
str
,
Any
]]
)
->
List
[
Dict
[
str
,
Any
]]:
r
"""Construct training examples from raw data.
Args:
raw_data (List[Dict[str, Any]]): List of raw data dictionaries
containing text and metadata.
Returns:
List[Dict[str, Any]]: List of constructed examples with QA pairs
and metadata.
"""
self
.
logger
.
info
(
"Starting to construct examples..."
)
examples
=
[]
for
data
in
tqdm
(
raw_data
,
desc
=
"Constructing examples"
):
# 1. Text preprocessing
processed_text
=
self
.
_preprocess_text
(
data
.
get
(
'text'
,
''
))
if
not
processed_text
:
example
=
{
# 'text': processed_text,
'qa_pairs'
:
[],
'metadata'
:
{
'source'
:
data
.
get
(
'source'
,
'unknown'
),
'timestamp'
:
data
.
get
(
'timestamp'
,
''
),
'complexity'
:
0
,
},
}
examples
.
append
(
example
)
continue
# 2. Generate key information pairs
info_pairs
=
self
.
_extract_info_pairs
(
processed_text
)
# 3. Construct question-answer pairs
if
(
info_pairs
):
qa_pairs
=
self
.
_generate_qa_pairs
(
info_pairs
)
else
:
qa_pairs
=
[]
# 4. Add metadata
example
=
{
# 'text': processed_text,
'qa_pairs'
:
qa_pairs
,
'metadata'
:
{
'source'
:
data
.
get
(
'source'
,
'unknown'
),
'timestamp'
:
data
.
get
(
'timestamp'
,
''
),
'complexity'
:
self
.
_calculate_complexity
(
qa_pairs
)
if
qa_pairs
else
0
,
},
}
examples
.
append
(
example
)
# self.logger.info(f"Successfully constructed {len(examples)} examples")
return
examples
def
_preprocess_text
(
self
,
text
:
str
)
->
str
:
r
"""Preprocess input text for example construction.
Args:
text (str): Input text to preprocess.
Returns:
str: Preprocessed text, or empty string if text fails quality
checks.
"""
if
not
isinstance
(
text
,
str
):
return
''
# 1. Basic cleaning
text
=
text
.
strip
()
# 2. Length check
if
(
len
(
text
)
<
self
.
min_length
or
len
(
text
)
>
self
.
max_length
):
self
.
logger
.
warning
(
"text fail to pass length check."
)
return
''
# 3. Quality check
if
not
self
.
_check_text_quality
(
text
):
self
.
logger
.
warning
(
"text fail to pass quality check."
)
return
''
return
text
def
_calculate_special_char_ratio
(
self
,
text
):
# 中文字符的Unicode范围(基本汉字+扩展)
chinese_ranges
=
[
(
0x4E00
,
0x9FFF
),
# 基本汉字
(
0x3400
,
0x4DBF
),
# 扩展A
(
0x20000
,
0x2A6DF
),
# 扩展B
(
0x2A700
,
0x2B73F
),
# 扩展C
(
0x2B740
,
0x2B81F
),
# 扩展D
(
0x2B820
,
0x2CEAF
)
# 扩展E
]
special_count
=
0
for
c
in
text
:
# 检查是否为中文、字母数字或空格
is_chinese
=
any
(
start
<=
ord
(
c
)
<=
end
for
start
,
end
in
chinese_ranges
)
if
not
(
c
.
isalnum
()
or
c
.
isspace
()
or
is_chinese
):
special_count
+=
1
return
special_count
/
len
(
text
)
if
text
else
0
def
_check_text_quality
(
self
,
text
:
str
)
->
bool
:
r
"""Check the quality of input text.
Args:
text (str): Text to check quality for.
Returns:
bool: True if text passes quality checks, False otherwise.
"""
# 1. Basic quality check
if
(
self
.
lang
==
"en"
and
text
.
count
(
'.'
)
<
2
):
# Must have at least 2 sentences
return
False
elif
(
self
.
lang
in
[
"zh"
,
"ch"
]
and
text
.
count
(
"。"
)
<
2
):
return
False
# 2. Special character ratio check
special_char_ratio
=
self
.
_calculate_special_char_ratio
(
text
)
if
special_char_ratio
>
0.3
:
# No more than 30% special characters
return
False
return
True
def
_extract_info_pairs
(
self
,
text
:
str
)
->
List
[
Dict
[
str
,
Sequence
[
str
]]]:
r
"""Extract information pairs and relationships from text.
Args:
text (str): Input text to extract information from.
Returns:
List[Dict[str, Sequence[str]]]: List of dictionaries containing
premise, intermediate, conclusion, and related contexts.
"""
# Split into sentences
if
(
self
.
lang
==
"en"
):
sentences
=
[
s
.
strip
()
for
s
in
text
.
split
(
'.'
)
if
s
.
strip
()]
else
:
sentences
=
[
s
.
strip
()
for
s
in
text
.
split
(
'。'
)
if
s
.
strip
()]
info_pairs
=
[]
# Extract combinations of multiple related sentences
for
i
in
range
(
len
(
sentences
)
-
2
):
if
len
(
sentences
[
i
])
>
10
and
len
(
sentences
[
i
+
1
])
>
10
:
info_pairs
.
append
(
{
'premise'
:
sentences
[
i
],
'intermediate'
:
sentences
[
i
+
1
],
'conclusion'
:
sentences
[
i
+
2
]
if
i
+
2
<
len
(
sentences
)
else
''
,
'related_contexts'
:
[
s
for
j
,
s
in
enumerate
(
sentences
)
if
j
!=
i
and
j
!=
i
+
1
and
len
(
s
)
>
10
][:
2
],
# Limit to 2 additional related contexts
}
)
return
info_pairs
    def _generate_qa_pairs(self, info_pairs: List[Dict[str, Sequence[str]]]) -> List[Dict[str, str]]:
        r"""Generate multi-hop question-answer pairs from information pairs.

        Args:
            info_pairs (List[Dict[str, Sequence[str]]]): List of information
                pairs extracted from text.

        Returns:
            List[Dict[str, str]]: List of generated QA pairs.
        """
        user_inputs = []
        for pair in info_pairs:
            # 1. Generate multi-hop question-answer pair using AI
            # Construct full context
            context = (
                f"{pair['premise']}. {pair['intermediate']}."
                f" {pair['conclusion']}"
            )
            user_inputs.append(self.prompt_template.build_prompt(context))
        sys_prompt = self.prompt_template.build_system_prompt()
        # NOTE(review): "llm_sering" is a long-standing misspelling of
        # "llm_serving"; the attribute is stored under this name in __init__,
        # so it must be read under the same name here.
        responses = self.llm_sering.generate_from_input(
            user_inputs=user_inputs,
            system_prompt=sys_prompt
        )
        # Parse the raw model responses into structured QA dictionaries.
        qa_pairs = self._extract_qa_pairs(responses)
        return qa_pairs
    def _extract_qa_pairs(self, responses: List[str]) -> List[Dict[str, Any]]:
        """
        Precisely extract well-structured QA pairs from raw model responses,
        automatically skipping invalid JSON and surrounding noise.
        """
        qa_pairs = []
        for response in responses:
            # Strategy 1: try to parse the whole response as JSON directly.
            try:
                qa_pair = json.loads(response)
                if isinstance(qa_pair, dict) and "question" in qa_pair:
                    qa_pairs.append(qa_pair)
                    continue
                elif isinstance(qa_pair, list):
                    for item in qa_pair:
                        if isinstance(item, dict) and "question" in item:
                            qa_pairs.append(item)
                    continue
            except json.JSONDecodeError:
                pass
            # Strategy 2: scan the response for embedded JSON objects.
            try:
                # Candidate regex for JSON objects (kept for reference; the
                # brace-matching scan below is what is actually used).
                json_pattern = r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}'
                # Match balanced top-level {...} spans by counting braces.
                brace_count = 0
                start_pos = -1
                json_objects = []
                for i, char in enumerate(response):
                    if char == '{':
                        if brace_count == 0:
                            start_pos = i
                        brace_count += 1
                    elif char == '}':
                        brace_count -= 1
                        if brace_count == 0 and start_pos != -1:
                            json_str = response[start_pos:i + 1]
                            json_objects.append(json_str)
                            start_pos = -1
                # Try to parse each candidate JSON string; only objects with
                # the full expected schema are accepted.
                for json_str in json_objects:
                    try:
                        qa_pair = json.loads(json_str)
                        if (isinstance(qa_pair, dict)
                                and "question" in qa_pair
                                and "reasoning_steps" in qa_pair
                                and "answer" in qa_pair
                                and "supporting_facts" in qa_pair
                                and "type" in qa_pair):
                            qa_pairs.append(qa_pair)
                    except json.JSONDecodeError as e:
                        self.logger.debug(f"Failed to parse JSON object: {json_str[:100]}... Error: {e}")
                        continue
                # Deduplicate questions (case-insensitive) across everything
                # accumulated so far.
                if qa_pairs:
                    seen_questions = set()
                    unique_qa_pairs = []
                    for qa_pair in qa_pairs:
                        question = qa_pair.get("question", "").strip().lower()
                        if question and question not in seen_questions:
                            seen_questions.add(question)
                            unique_qa_pairs.append(qa_pair)
                            self.logger.debug(f"Added unique question: {qa_pair['question']}")
                        else:
                            self.logger.debug(f"Skipped duplicate question: {qa_pair.get('question', 'N/A')}")
                    qa_pairs = unique_qa_pairs
                # Warn when the fallback scan found nothing parseable.
                if not json_objects:
                    self.logger.warning("No JSON objects found in model response.")
            except Exception as e:
                self.logger.warning(f"Failed to parse QA information from model response. Error: {e}")
        return qa_pairs
def
_calculate_complexity
(
self
,
qa_pairs
:
List
[
Dict
[
str
,
Any
]])
->
float
:
r
"""Calculate the complexity score for a set of QA pairs.
Args:
qa_pairs (List[Dict[str, Any]]): List of QA pairs to calculate
complexity for.
Returns:
float: Complexity score between 0.0 and 1.0.
"""
if
not
qa_pairs
:
return
0.0
# Calculate complexity based on multiple factors
complexities
=
[]
for
qa
in
qa_pairs
:
# 1. Number of reasoning steps
reasoning_steps_count
=
len
(
qa
.
get
(
'reasoning_steps'
,
[]))
# 2. Number of supporting facts
supporting_facts_count
=
len
(
qa
.
get
(
'supporting_facts'
,
[]))
# 3. Question length
question_length
=
len
(
qa
.
get
(
'question'
,
''
).
split
())
# 4. Answer length
answer_length
=
len
(
qa
.
get
(
'answer'
,
''
).
split
())
# Calculate complexity of a single QA pair
qa_complexity
=
(
min
(
reasoning_steps_count
/
3
,
1.0
)
*
0.4
# Weight for reasoning steps
+
min
(
supporting_facts_count
/
3
,
1.0
)
*
0.3
# Weight for supporting facts
+
min
(
question_length
/
20
,
1.0
)
*
0.15
# Weight for question length
+
min
(
answer_length
/
50
,
1.0
)
*
0.15
# Weight for answer length
)
complexities
.
append
(
qa_complexity
)
return
sum
(
complexities
)
/
len
(
complexities
)
dataflow/operators/knowledge_cleaning/generate/kbc_text_cleaner.py
0 → 100644
View file @
97e8278b
from
dataflow.prompts.kbcleaning
import
KnowledgeCleanerPrompt
import
pandas
as
pd
from
dataflow.utils.registry
import
OPERATOR_REGISTRY
from
dataflow
import
get_logger
from
dataflow.utils.storage
import
DataFlowStorage
from
dataflow.core
import
OperatorABC
from
dataflow.core
import
LLMServingABC
from
dataflow.core.prompt
import
prompt_restrict
,
DIYPromptABC
from
typing
import
Union
import
re
@prompt_restrict(KnowledgeCleanerPrompt)
@OPERATOR_REGISTRY.register()
class KBCTextCleaner(OperatorABC):
    '''
    KnowledgeCleaner is a class that cleans knowledge for RAG to make them more accurate, reliable and readable.
    '''

    def __init__(self,
                 llm_serving: LLMServingABC,
                 lang="en",
                 prompt_template: Union[KnowledgeCleanerPrompt, DIYPromptABC] = None
                 ):
        """Initialize the cleaner.

        Args:
            llm_serving (LLMServingABC): Backend used to run the cleaning prompt.
            lang (str, optional): Prompt language, "en" or "zh". (default: "en")
            prompt_template: Custom prompt template; when omitted a default
                KnowledgeCleanerPrompt for *lang* is used.
        """
        self.logger = get_logger()
        self.prompts = KnowledgeCleanerPrompt(lang=lang)
        self.llm_serving = llm_serving
        if prompt_template:
            self.prompt_template = prompt_template
        else:
            # Bug fix: honor the requested language. Previously the default
            # template was built without lang (unlike KBCTextCleanerBatch,
            # which passes lang through), so a non-default lang was ignored.
            self.prompt_template = KnowledgeCleanerPrompt(lang=lang)
    @staticmethod
    def get_desc(lang: str = "zh"):
        # Registry description with before/after examples of the cleaning output.
        if lang == "zh":
            return (
                "知识清洗算子:对原始知识内容进行标准化处理,包括HTML标签清理、特殊字符规范化、"
                "链接处理和结构优化,提升RAG知识库的质量。主要功能:\n"
                "1. 移除冗余HTML标签但保留语义化标签\n"
                "2. 标准化引号/破折号等特殊字符\n"
                "3. 处理超链接同时保留文本\n"
                "4. 保持原始段落结构和代码缩进\n"
                "5. 确保事实性内容零修改\n"
                "\n输入格式示例:\n"
                "<div class=\"container\">\n"
                "  <h1>标题文本</h1>\n"
                "  <p>正文段落,包括特殊符号,例如“弯引号”、–破折号等</p>\n"
                "  <img src=\"example.jpg\" alt=\"示意图\">\n"
                "  <a href=\"...\">链接文本</a>\n"
                "  <pre><code>代码片段</code></pre>\n"
                "  ...\n"
                "</div>\n"
                "\n输出格式示例:\n"
                "标题文本\n\n"
                "正文段落,包括特殊符号,例如\"直引号\"、-破折号等\n\n"
                "[Image: 示例图 example.jpg]\n\n"
                "链接文本\n\n"
                "<code>代码片段</code>\n\n"
                "[结构保持,语义保留,敏感信息脱敏处理(如手机号、保密标记等)]"
            )
        elif lang == "en":
            return (
                "Knowledge Cleaning Operator: Standardizes raw HTML/text content for RAG quality improvement. Key functions:\n"
                "1. Removes redundant HTML tags while preserving semantic tags\n"
                "2. Normalizes special characters (e.g., curly quotes, dashes)\n"
                "3. Processes hyperlinks and retains their text\n"
                "4. Preserves paragraph structure and code indentation\n"
                "5. Ensures factual content remains unchanged\n"
                "\nExample Input Format:\n"
                "<div class=\"container\">\n"
                "  <h1>Title Text</h1>\n"
                "  <p>Paragraph with “curly quotes” and – dashes</p>\n"
                "  <img src=\"example.jpg\" alt=\"Diagram\">\n"
                "  <a href=\"...\">Link text</a>\n"
                "  <pre><code>Code block</code></pre>\n"
                "  ...\n"
                "</div>\n"
                "\nExample Output Format:\n"
                "Title Text\n\n"
                "Paragraph with \"straight quotes\" and - dashes\n\n"
                "[Image: Diagram example.jpg]\n\n"
                "Link text\n\n"
                "<code>Code block</code>\n\n"
                "[Structure retained, semantics preserved, sensitive info masked (e.g., phone numbers, confidential tags)]"
            )
        else:
            return "Knowledge cleaning operator for RAG content standardization. Set lang='zh' or 'en' for examples."
def
_validate_dataframe
(
self
,
dataframe
:
pd
.
DataFrame
):
required_keys
=
[
self
.
input_key
]
forbidden_keys
=
[
self
.
output_key
]
missing
=
[
k
for
k
in
required_keys
if
k
not
in
dataframe
.
columns
]
conflict
=
[
k
for
k
in
forbidden_keys
if
k
in
dataframe
.
columns
]
if
missing
:
raise
ValueError
(
f
"Missing required column(s):
{
missing
}
"
)
if
conflict
:
raise
ValueError
(
f
"The following column(s) already exist and would be overwritten:
{
conflict
}
"
)
def
_reformat_prompt
(
self
,
dataframe
):
"""
Reformat the prompts in the dataframe to generate questions.
"""
raw_contents
=
dataframe
[
self
.
input_key
].
tolist
()
inputs
=
[
self
.
prompt_template
.
build_prompt
(
raw_content
)
for
raw_content
in
raw_contents
]
return
inputs
def
run
(
self
,
storage
:
DataFlowStorage
,
input_key
:
str
=
"raw_chunk"
,
output_key
:
str
=
"cleaned_chunk"
):
'''
Runs the knowledge cleaning process, reading from the input key and saving results to output key.
'''
self
.
input_key
,
self
.
output_key
=
input_key
,
output_key
dataframe
=
storage
.
read
(
"dataframe"
)
self
.
_validate_dataframe
(
dataframe
)
formatted_prompts
=
self
.
_reformat_prompt
(
dataframe
)
cleaned
=
self
.
llm_serving
.
generate_from_input
(
formatted_prompts
,
""
)
#for each in cleaned, only save the content in <cleaned_start> and <cleaned_end>
cleaned_extracted
=
[
str
(
text
).
split
(
'<cleaned_start>'
)[
1
].
split
(
'<cleaned_end>'
)[
0
].
strip
()
if
'<cleaned_start>'
in
str
(
text
)
and
'<cleaned_end>'
in
str
(
text
)
else
str
(
text
).
strip
()
for
text
in
cleaned
]
dataframe
[
self
.
output_key
]
=
cleaned_extracted
output_file
=
storage
.
write
(
dataframe
)
self
.
logger
.
info
(
f
"Results saved to
{
output_file
}
"
)
return
[
output_key
]
\ No newline at end of file
dataflow/operators/knowledge_cleaning/generate/kbc_text_cleaner_batch.py
0 → 100644
View file @
97e8278b
from
dataflow.prompts.kbcleaning
import
KnowledgeCleanerPrompt
import
pandas
as
pd
from
dataflow.utils.registry
import
OPERATOR_REGISTRY
from
dataflow
import
get_logger
import
json
from
dataflow.utils.storage
import
DataFlowStorage
from
dataflow.core
import
OperatorABC
from
dataflow.core
import
LLMServingABC
from
dataflow.core.prompt
import
prompt_restrict
,
DIYPromptABC
from
typing
import
Union
import
re
@prompt_restrict(KnowledgeCleanerPrompt)
@OPERATOR_REGISTRY.register()
class KBCTextCleanerBatch(OperatorABC):
    '''
    KnowledgeCleaner is a class that cleans knowledge for RAG to make them more accurate, reliable and readable.
    '''

    def __init__(self,
                 llm_serving: LLMServingABC,
                 lang="en",
                 prompt_template: Union[KnowledgeCleanerPrompt, DIYPromptABC] = None
                 ):
        """Set up logging, the language-specific prompts and the LLM backend."""
        self.logger = get_logger()
        self.prompts = KnowledgeCleanerPrompt(lang=lang)
        self.llm_serving = llm_serving
        # Use the caller-supplied template when given, otherwise a default
        # KnowledgeCleanerPrompt for *lang*.
        self.prompt_template = prompt_template or KnowledgeCleanerPrompt(lang=lang)
    @staticmethod
    def get_desc(lang: str = "zh"):
        # Short registry description (the batch variant omits the I/O examples).
        if lang == "zh":
            return (
                "知识清洗算子:对原始知识内容进行标准化处理,包括HTML标签清理、特殊字符规范化、"
                "链接处理和结构优化,提升RAG知识库的质量。主要功能:\n"
                "1. 移除冗余HTML标签但保留语义化标签\n"
                "2. 标准化引号/破折号等特殊字符\n"
                "3. 处理超链接同时保留文本\n"
                "4. 保持原始段落结构和代码缩进\n"
                "5. 确保事实性内容零修改"
            )
        elif lang == "en":
            return (
                "Knowledge Cleaning Operator: Standardizes raw content for RAG by:\n"
                "1. Removing redundant HTML tags while preserving semantic markup\n"
                "2. Normalizing special characters (quotes/dashes)\n"
                "3. Processing hyperlinks with text preservation\n"
                "4. Maintaining original paragraph structure and code indentation\n"
                "5. Guaranteeing zero modification of factual content"
            )
        else:
            return "Knowledge cleaning operator for RAG content standardization"
def
_validate_dataframe
(
self
,
dataframe
:
pd
.
DataFrame
):
required_keys
=
[
self
.
input_key
]
forbidden_keys
=
[
self
.
output_key
]
missing
=
[
k
for
k
in
required_keys
if
k
not
in
dataframe
.
columns
]
conflict
=
[
k
for
k
in
forbidden_keys
if
k
in
dataframe
.
columns
]
if
missing
:
raise
ValueError
(
f
"Missing required column(s):
{
missing
}
"
)
if
conflict
:
raise
ValueError
(
f
"The following column(s) already exist and would be overwritten:
{
conflict
}
"
)
def
_reformat_prompt
(
self
,
dataframe
):
"""
Reformat the prompts in the dataframe to generate questions.
"""
raw_contents
=
dataframe
[
self
.
input_key
].
tolist
()
inputs
=
[
self
.
prompt_template
.
build_prompt
(
raw_content
)
for
raw_content
in
raw_contents
]
return
inputs
def
_reformat_prompt_from_path
(
self
,
chunk_path
:
str
)
->
list
:
"""
Reformat the prompts in the file (JSON or JSONL) to generate question prompts.
Args:
chunk_path (str): Path to the .json or .jsonl file containing raw chunks.
Returns:
list: A list of formatted prompt strings.
"""
if
chunk_path
.
endswith
(
".json"
):
dataframe
=
pd
.
read_json
(
chunk_path
)
elif
chunk_path
.
endswith
(
".jsonl"
):
dataframe
=
pd
.
read_json
(
chunk_path
,
lines
=
True
)
else
:
raise
ValueError
(
"Unsupported file format. Only .json and .jsonl are supported."
)
if
"raw_chunk"
not
in
dataframe
.
columns
:
raise
KeyError
(
"'raw_chunk' field not found in the input file."
)
raw_contents
=
dataframe
[
"raw_chunk"
].
tolist
()
inputs
=
[
self
.
prompts
.
build_prompt
(
raw_content
)
for
raw_content
in
raw_contents
]
return
raw_contents
,
inputs
def
run
(
self
,
storage
:
DataFlowStorage
,
input_key
:
str
=
"chunk_path"
,
output_key
:
str
=
"cleaned_chunk_path"
):
'''
Runs the knowledge cleaning process, reading from the input key and saving results to output key.
'''
self
.
input_key
,
self
.
output_key
=
input_key
,
output_key
dataframe
=
storage
.
read
(
"dataframe"
)
self
.
_validate_dataframe
(
dataframe
)
chunk_paths
=
dataframe
[
self
.
input_key
].
tolist
()
for
chunk_path
in
chunk_paths
:
if
(
chunk_path
):
raw_chunks
,
formatted_prompts
=
self
.
_reformat_prompt_from_path
(
chunk_path
)
cleaned
=
self
.
llm_serving
.
generate_from_input
(
formatted_prompts
,
""
)
# for each in cleaned, only save the content in <cleaned_start> and <cleaned_end>
cleaned_extracted
=
[
text
.
split
(
'<cleaned_start>'
)[
1
].
split
(
'<cleaned_end>'
)[
0
].
strip
()
if
'<cleaned_start>'
in
str
(
text
)
and
'<cleaned_end>'
in
str
(
text
)
else
str
(
text
).
strip
()
for
text
in
cleaned
]
json_items
=
[{
"raw_chunk"
:
raw_chunk
,
"cleaned_chunk"
:
cleaned_chunk
}
for
raw_chunk
,
cleaned_chunk
in
zip
(
raw_chunks
,
cleaned_extracted
)]
with
open
(
chunk_path
,
"w"
,
encoding
=
"utf-8"
)
as
f
:
json
.
dump
(
json_items
,
f
,
ensure_ascii
=
False
,
indent
=
4
)
self
.
logger
.
info
(
f
"Successfully cleaned contents in
{
chunk_path
}
"
)
dataframe
[
self
.
output_key
]
=
chunk_paths
output_file
=
storage
.
write
(
dataframe
)
self
.
logger
.
info
(
f
"Results saved to
{
output_file
}
"
)
return
[
output_key
]
dataflow/operators/knowledge_cleaning/generate/mathbook_question_extract.py
0 → 100644
View file @
97e8278b
import
sys
from
dataflow.utils.registry
import
OPERATOR_REGISTRY
from
dataflow
import
get_logger
from
dataflow.core
import
OperatorABC
import
os
from
pathlib
import
Path
import
json
import
shutil
import
fitz
# pip install pymupdf
from
dataflow.prompts.kbcleaning
import
MathbookQuestionExtractPrompt
import
re
from
openai
import
OpenAI
import
base64
from
typing
import
Literal
,
Union
from
dataflow.core
import
LLMServingABC
from
dataflow.serving
import
APIVLMServing_openai
from
dataflow.core.prompt
import
DIYPromptABC
from
dataflow.utils.storage
import
DataFlowStorage
@OPERATOR_REGISTRY.register()
class MathBookQuestionExtract(OperatorABC):
    """Extracts questions and related images from mathematics-textbook PDFs
    using MinerU for content extraction and a VLM for analysis."""

    def __init__(self,
                 llm_serving: APIVLMServing_openai,
                 prompt_template: Union[MathbookQuestionExtractPrompt, DIYPromptABC] = None,
                 mineru_backend: str = "vlm-vllm-engine",
                 dpi: int = 300,
                 key_name_of_api_key: str = "DF_API_KEY",
                 model_name: str = "o4-mini",
                 max_workers: int = 20
                 ):
        """Initialize the extractor.

        Args:
            llm_serving (APIVLMServing_openai): VLM serving backend.
            prompt_template: Prompt template; a fresh
                MathbookQuestionExtractPrompt is created when omitted.
            mineru_backend (str): MinerU backend identifier.
            dpi (int): Resolution for PDF-to-image conversion.
            key_name_of_api_key (str): Environment variable holding the API key.
            model_name (str): Model used for analysis.
            max_workers (int): Maximum number of parallel workers.
        """
        self.logger = get_logger()
        self.llm_serving = llm_serving
        # Bug fix: the old signature used ``= MathbookQuestionExtractPrompt()``
        # as the default, which is evaluated once at class-definition time and
        # shared by every instance (mutable-default pitfall). Build a fresh
        # default per instance instead; passing a template still works as before.
        self.prompt_template = prompt_template if prompt_template is not None else MathbookQuestionExtractPrompt()
        self.mineru_backend = mineru_backend
        self.dpi = dpi
        self.key_name_of_api_key = key_name_of_api_key
        self.model_name = model_name
        # NOTE: max_workers was not used by the original logic; it is kept on
        # the instance as required.
        self.max_workers = max_workers
    @staticmethod
    def get_desc(lang: str = "zh"):
        """Return a human-readable operator description in the requested language.

        Args:
            lang: "zh" for Chinese, "en" for English; any other value falls
                back to a one-line English summary.
        """
        if lang == "zh":
            return (
                "该算子用于从数学教材PDF中提取问题和相关图片内容。它将PDF转换为图片,使用MinerU进行内容提取,"
                "然后组织图片并使用大语言模型分析内容,最终生成包含问题和图片的JSON和Markdown文件。\n"
                "输入参数:\n"
                "- llm_serving:VLM服务对象,需实现APIVLMServing_openai接口\n"
                "- pdf_file_path:PDF文件路径\n"
                "- output_file_name:输出文件名\n"
                "- output_folder:输出文件夹路径\n"
                "- MinerU_Backend:MinerU后端类型,默认为'vlm-sglang-engine'\n"
                "- dpi:PDF转图片的分辨率,默认为300\n"
                "- api_url:API服务URL\n"
                "- key_name_of_api_key:API密钥的环境变量名\n"
                "- model_name:使用的模型名称,默认为'o4-mini'\n"
                "- max_workers:最大并行工作线程数,默认为20\n"
                "输出参数:\n"
                "- 返回布尔值表示处理是否成功\n"
                "- 在指定文件夹生成JSON和Markdown格式的提取结果"
            )
        elif lang == "en":
            return (
                "This operator extracts questions and related images from mathematics textbook PDFs. It converts the PDF to images, "
                "uses MinerU for content extraction, organizes the images, and analyzes the content using a large vision-language model, "
                "ultimately generating JSON and Markdown files containing questions and images.\n"
                "Input Parameters:\n"
                "- llm_serving: VLM serving object implementing APIVLMServing_openai interface\n"
                "- pdf_file_path: Path to the PDF file\n"
                "- output_file_name: Name for the output files\n"
                "- output_folder: Path to the output folder\n"
                "- MinerU_Backend: MinerU backend type, default is 'vlm-sglang-engine'\n"
                "- dpi: Resolution for PDF to image conversion, default is 300\n"
                "- api_url: API service URL\n"
                "- key_name_of_api_key: Environment variable name for API key\n"
                "- model_name: Model name to use, default is 'o4-mini'\n"
                "- max_workers: Maximum number of parallel workers, default is 20\n\n"
                "Output Parameters:\n"
                "- Returns boolean indicating success of processing\n"
                "- Generates extraction results in JSON and Markdown formats in the specified folder"
            )
        else:
            return (
                "MathBookQuestionExtract processes mathematics textbook PDFs to extract questions and images using MinerU and VLM."
            )
def
mineru2_runner
(
self
,
pdf_file_path
:
str
,
output_folder
:
str
,
# pipeline|vlm-transformers|vlm-vllm-engine|vlm-http-client
mineru_backend
:
Literal
[
"pipeline"
,
"vlm-transformers"
,
"vlm-vllm-engine"
,
"vlm-http-client"
]
=
"pipeline"
):
try
:
import
mineru
except
ImportError
:
raise
Exception
(
"""
MinerU is not installed in this environment yet.
Please refer to https://github.com/opendatalab/mineru to install.
Or you can just execute 'pip install mineru[pipeline]' and 'mineru-models-download' to fix this error.
Please make sure you have GPU on your machine.
"""
)
os
.
environ
[
'MINERU_MODEL_SOURCE'
]
=
"local"
# 可选:从本地加载模型
MinerU_Version
=
{
"pipeline"
:
"auto"
,
"vlm-transformers"
:
"vlm"
,
'vlm-vllm-engine'
:
'vlm'
,
'vlm-http-client'
:
'vlm'
}
raw_file
=
Path
(
pdf_file_path
)
pdf_name
=
raw_file
.
stem
intermediate_dir
=
output_folder
try
:
return_code
=
os
.
system
(
f
"mineru -p
{
raw_file
}
-o
{
intermediate_dir
}
-b
{
mineru_backend
}
--source local"
)
if
return_code
!=
0
:
raise
RuntimeError
(
f
"MinerU execution failed with return code:
{
return_code
}
"
)
except
Exception
as
e
:
raise
RuntimeError
(
f
"Failed to process file with MinerU:
{
str
(
e
)
}
"
)
output_file
=
os
.
path
.
join
(
intermediate_dir
,
pdf_name
,
MinerU_Version
[
mineru_backend
],
f
"
{
pdf_name
}
_content_list.json"
)
output_pic_folder
=
os
.
path
.
join
(
intermediate_dir
,
pdf_name
,
MinerU_Version
[
mineru_backend
],
"images"
)
self
.
logger
.
info
(
f
"MinerU json file has been saved to
{
output_file
}
"
)
return
output_file
,
output_pic_folder
def
organize_pics
(
self
,
mineru_content_json_path
:
str
,
mineru_image_folder
:
str
,
output_file_path
:
str
,
output_pic_folder
:
str
):
'''
用来把mineru切割出来的图片组织到最终文件夹下的辅助函数
输入:
mineru_content_json_path: mineru切割出来的json文件路径
mineru_image_folder: mineru切割出来的图片文件夹路径
输出:
output_file_path: 组织图片后的图片信息记录文件,服务后续的图片处理
output_pic_folder: 最终组织后的图片文件夹路径
'''
global_counter
=
0
global_json_data
=
[]
# read mineru content json
json_data
=
json
.
load
(
open
(
mineru_content_json_path
,
'r'
))
# if output_pic_folder is not exist, create it
if
not
os
.
path
.
exists
(
output_pic_folder
):
os
.
makedirs
(
output_pic_folder
)
for
item
in
json_data
:
if
item
[
'type'
]
==
'image'
:
# get the image name
image_name
=
item
[
'img_path'
].
split
(
'/'
)[
-
1
]
# get the image path
image_path
=
os
.
path
.
join
(
mineru_image_folder
,
image_name
)
page_idx
=
item
[
'page_idx'
]
# rename the image
new_image_name
=
f
"
{
global_counter
}
.jpg"
new_image_path
=
os
.
path
.
join
(
output_pic_folder
,
new_image_name
)
shutil
.
copy
(
image_path
,
new_image_path
)
# add to global json data
global_json_data
.
append
({
"img_path"
:
new_image_path
,
"page_idx"
:
page_idx
,
})
global_counter
+=
1
# write to json file
with
open
(
output_file_path
,
'w'
)
as
f
:
json
.
dump
(
global_json_data
,
f
,
indent
=
4
)
def
pdf2images
(
self
,
pdf_path
:
str
,
output_folder
:
str
,
dpi
:
int
=
300
):
'''
用来把pdf文件转换为图片的辅助函数
输入:
pdf_path: pdf文件路径
output_folder: 输出图片文件夹路径
'''
doc
=
fitz
.
open
(
pdf_path
)
# make output directory if it doesn't exist
os
.
makedirs
(
output_folder
,
exist_ok
=
True
)
# convert each page to image
for
page_index
in
range
(
len
(
doc
)):
page
=
doc
.
load_page
(
page_index
)
pix
=
page
.
get_pixmap
(
dpi
=
dpi
)
pix
.
save
(
f
"
{
output_folder
}
/page_
{
page_index
}
.jpg"
)
self
.
logger
.
info
(
f
"Converted page
{
page_index
}
to image"
)
return
True
def
encode_image_to_base64
(
self
,
image_path
:
str
)
->
str
:
with
open
(
image_path
,
"rb"
)
as
f
:
return
base64
.
b64encode
(
f
.
read
()).
decode
(
"utf-8"
)
    def process_input(self, page_folder: str, img_json_path: str):
        """Build per-page-pair image/label batches for the VLM call.

        For each consecutive page pair (i, i+1) it assembles a list of image
        paths — the two full-page renders plus any MinerU-extracted figures
        recorded for either page — and a parallel list of labels (the page
        names and figure file names).

        Args:
            page_folder: folder containing page_<n>.jpg renders from pdf2images().
            img_json_path: JSON produced by organize_pics()
                ([{"img_path": ..., "page_idx": ...}, ...]).

        Returns:
            (full_input_image_list, full_input_label_list): parallel lists,
            one entry per page pair.
        """
        # Collect every page_<n>.jpg render in the folder.
        page_list = [os.path.join(page_folder, f) for f in os.listdir(page_folder) if f.endswith(('.jpg'))]
        # Parse <n> out of ".../page_<n>.jpg".
        # NOTE(review): splitting on "/" assumes POSIX-style paths — os.path.join
        # uses "\\" on Windows; confirm the target platform.
        idx_list = [int(f.split("/")[-1].split(".")[0].split("_")[-1]) for f in page_list]
        max_page_idx = max(idx_list)
        # Load the organized image info produced by organize_pics().
        # NOTE(review): open() handle is never closed here (json.load(open(...))).
        img_json = json.load(open(img_json_path, "r"))
        # Group extracted-figure paths by the page they came from.
        img_dict = {}
        for item in img_json:
            if item["page_idx"] not in img_dict:
                img_dict[item["page_idx"]] = []
            img_dict[item["page_idx"]].append(item["img_path"])
        full_input_image_list = []
        full_input_label_list = []
        # One batch per consecutive page pair (i, i+1); range(max_page_idx)
        # keeps i+1 in bounds. A single-page PDF yields no batches.
        for page_idx in range(max_page_idx):
            image_list = []
            label_list = []
            image_list.append(os.path.join(page_folder, f"page_{page_idx}.jpg"))
            label_list.append(f"page_{page_idx}")
            image_list.append(os.path.join(page_folder, f"page_{page_idx + 1}.jpg"))
            label_list.append(f"page_{page_idx + 1}")
            # Attach figures extracted from either page of the pair, labelled
            # by their file name (e.g. "3.jpg").
            if page_idx in img_dict:
                image_list.extend(img_dict[page_idx])
                label_list.extend([img_path.split("/")[-1] for img_path in img_dict[page_idx]])
            if page_idx + 1 in img_dict:
                image_list.extend(img_dict[page_idx + 1])
                label_list.extend([img_path.split("/")[-1] for img_path in img_dict[page_idx + 1]])
            full_input_image_list.append(image_list)
            full_input_label_list.append(label_list)
        return full_input_image_list, full_input_label_list
    def analyze_and_save(self, result_list, save_folder, img_folder, output_file_name):
        """Parse VLM responses into JSON + Markdown files and copy referenced images.

        Each response is split on "<SPACE>" into question texts; inline
        "<image>NAME.jpg</image>" tags are resolved against *img_folder*,
        the referenced images are copied into save_folder/images, and two
        outputs are written: <output_file_name>.json (text + picture paths)
        and <output_file_name>.md (one question per "---"-separated section).

        Args:
            result_list: raw VLM response strings (falsy entries are skipped).
            save_folder: destination folder for the JSON/MD files.
            img_folder: folder holding the organized images.
            output_file_name: basename (without extension) for the outputs.

        Returns:
            (output_json, output_markdown_text).
        """
        # ... (analyze_and_save method kept unchanged from the original source)
        # Create save_folder and save_folder/images if they do not exist.
        if not os.path.exists(save_folder):
            os.makedirs(save_folder)
        if not os.path.exists(os.path.join(save_folder, "images")):
            os.makedirs(os.path.join(save_folder, "images"))
        output_json = []
        output_markdown_text = ""
        for item in result_list:
            if not item:
                continue
            # Responses use "<SPACE>" as the separator between questions.
            split_text = item.split("<SPACE>")
            for text in split_text:
                if not text:
                    continue
                # Find all inline image tags of the strict form
                # <image>NAME.jpg</image> (e.g. <image>1.jpg</image>).
                pic_list = []
                pic_match = re.findall(r'<image>(.*?)\.jpg</image>', text)
                if pic_match:
                    for pic_name in pic_match:
                        # Store the full path to the referenced image.
                        pic_list.append(os.path.join(img_folder, f"{pic_name}.jpg"))
                    # JSON-style text: strip every <image>NAME.jpg</image> tag.
                    json_text = re.sub(r'<image>(.*?)\.jpg</image>', '', text)
                    # Markdown-style text: replace each tag (see NOTE below).
                    markdown_text = text
                    for pic_name in pic_match:
                        # Copy img_folder/NAME.jpg into save_folder/images/NAME.jpg.
                        shutil.copy(os.path.join(img_folder, f"{pic_name}.jpg"),
                                    os.path.join(save_folder, "images", f"{pic_name}.jpg"))
                        # NOTE(review): the replacement is the empty string f"",
                        # which deletes the tag rather than inserting a Markdown
                        # image link — this looks like a lost
                        # "![](images/NAME.jpg)" literal; confirm against the
                        # original repository.
                        markdown_text = markdown_text.replace(f"<image>{pic_name}.jpg</image>", f"")
                else:
                    json_text = text
                    markdown_text = text
                    pic_list = []
                json_data = {
                    "text": json_text,
                    "pics": pic_list
                }
                output_json.append(json_data)
                output_markdown_text += markdown_text + "\n" + "---" + "\n"
        # Persist the JSON result.
        with open(os.path.join(save_folder, f"{output_file_name}.json"), "w") as f:
            json.dump(output_json, f, indent=4, ensure_ascii=False)
        # Persist the Markdown result.
        with open(os.path.join(save_folder, f"{output_file_name}.md"), "w", encoding="utf-8") as f:
            f.write(output_markdown_text)
        return output_json, output_markdown_text
    def run(self,
            storage: DataFlowStorage,
            input_pdf_file_path: str,
            output_file_name: str,
            output_folder: str,
            ):
        """Full pipeline: PDF -> page images -> MinerU extraction -> VLM -> JSON/MD.

        Args:
            storage: DataFlow storage handle (accepted for interface
                conformity; not read or written by this method).
            input_pdf_file_path: PDF to process.
            output_file_name: basename for the final JSON/Markdown outputs.
            output_folder: root folder for all intermediate and final files.

        Returns:
            True on success.

        Raises:
            ValueError: when the API-key environment variable is unset/empty.
        """
        # Get the configuration parameters from self.
        api_key = os.environ.get(self.key_name_of_api_key)
        if not api_key:
            raise ValueError(f"API key not found in environment variable {self.key_name_of_api_key}")
        # 1. Convert the PDF to per-page images.
        pdf2images_folder_name = os.path.join(output_folder, "pdfimages")
        self.pdf2images(input_pdf_file_path, pdf2images_folder_name, self.dpi)
        # 2. Use MinerU to extract content and pictures.
        json_content_file, pic_folder = self.mineru2_runner(input_pdf_file_path, output_folder, self.mineru_backend)
        # 3. Rename/collect extracted pictures into one folder.
        output_image_folder = os.path.join(output_folder, "organized_images")
        output_json_file = os.path.join(output_image_folder, "organized_info.json")
        self.organize_pics(json_content_file, pic_folder, output_json_file, output_image_folder)
        # 4. Build the per-page-pair image batches.
        full_input_image_list, full_input_label_list = self.process_input(pdf2images_folder_name, output_json_file)
        # 5. Build the system prompt and query the VLM.
        system_prompt = self.prompt_template.build_prompt()
        result_text_list = self.llm_serving.generate_from_input_multi_images(
            list_of_image_paths=full_input_image_list,
            list_of_image_labels=full_input_label_list,
            system_prompt=system_prompt,
            model=self.model_name,
            timeout=1800
        )
        # 6. Save the responses as JSON + Markdown.
        self.analyze_and_save(result_text_list, output_folder, output_image_folder, output_file_name)
        # 7. Done.
        return True
dataflow/operators/knowledge_cleaning/generate/qa_extract.py
0 → 100644
View file @
97e8278b
#!/usr/bin/env python3
"""QA Extractor - 提取QA对并转换为Alpaca格式"""
import
json
from
pathlib
import
Path
from
typing
import
Optional
,
List
from
dataflow.core
import
OperatorABC
from
dataflow.utils.registry
import
OPERATOR_REGISTRY
from
dataflow.utils.storage
import
DataFlowStorage
from
dataflow
import
get_logger
@
OPERATOR_REGISTRY
.
register
()
class
QAExtractor
(
OperatorABC
):
"""
从QA_pairs字段提取问答对,转换为Alpaca微调格式
Input: QA_pairs (nested structure)
Output: instruction, input, output (Alpaca format)
"""
    def __init__(self,
                 qa_key: str = "QA_pairs",
                 output_json_file: Optional[str] = None,
                 instruction: str = "Please answer the following question based on the provided information."):
        """Configure the QA extractor.

        Args:
            qa_key: name of the column/field holding the (possibly nested)
                QA-pair structure.
            output_json_file: optional path; when set, run() also dumps the
                extracted Alpaca records to this JSON file.
            instruction: instruction string placed in every output record.
        """
        self.logger = get_logger()
        self.qa_key = qa_key
        self.output_json_file = output_json_file
        self.instruction = instruction
    @staticmethod
    def get_desc(lang: str = "zh"):
        """Return the operator description ("zh" Chinese, otherwise English)."""
        if lang == "zh":
            return (
                "QA对提取器 - 将嵌套的QA_pairs转换为Alpaca微调格式\n\n"
                "核心功能:\n"
                "从结构化的QA对数据中提取问答内容,自动整合推理步骤和支持事实,\n"
                "输出符合Stanford Alpaca标准的instruction-input-output格式。\n\n"
                "初始化参数:\n"
                "• qa_key: QA对的字段名 (默认: 'QA_pairs')\n"
                "• output_json_file: 输出JSON文件路径 (可选,不指定则只更新DataFrame)\n"
                "• instruction: 统一的指令前缀 (默认: 'Please answer the following question...')\n\n"
                "运行参数 (input_key):\n"
                "• None - 包含所有字段 (question + reasoning_steps + supporting_facts)\n"
                "• '' - 空字符串,不包含额外上下文\n"
                "• 'reasoning_steps' - 只包含推理步骤\n"
                "• 'question,reasoning_steps' - 逗号分隔多个字段\n"
                "• ['question', 'supporting_facts'] - 列表格式\n\n"
                "输出字段:\n"
                "• instruction: 问题指令\n"
                "• input: 上下文信息 (根据input_key动态拼接)\n"
                "• output: 答案\n\n"
                "适用场景: 知识库QA微调、领域问答模型训练"
            )
        else:
            # English
            return (
                "QA Extractor - Convert nested QA_pairs to Alpaca fine-tuning format\n\n"
                "Core Function:\n"
                "Extract question-answer pairs from structured data, automatically integrate\n"
                "reasoning steps and supporting facts, output in Stanford Alpaca standard\n"
                "instruction-input-output format.\n\n"
                "Initialization Parameters:\n"
                "• qa_key: Field name for QA pairs (default: 'QA_pairs')\n"
                "• output_json_file: Output JSON path (optional, skip to only update DataFrame)\n"
                "• instruction: Unified instruction prefix (default: 'Please answer...')\n\n"
                "Runtime Parameters (input_key):\n"
                "• None - Include all fields (question + reasoning_steps + supporting_facts)\n"
                "• '' - Empty string, no additional context\n"
                "• 'reasoning_steps' - Only reasoning steps\n"
                "• 'question,reasoning_steps' - Comma-separated fields\n"
                "• ['question', 'supporting_facts'] - List format\n\n"
                "Output Fields:\n"
                "• instruction: Question as instruction\n"
                "• input: Context information (dynamically assembled by input_key)\n"
                "• output: Answer\n\n"
                "Use Cases: Knowledge base QA fine-tuning, domain-specific Q&A training"
            )
def
_parse_fields
(
self
,
input_key
:
Optional
[
str
])
->
Optional
[
List
[
str
]]:
"""解析要包含的字段"""
if
input_key
is
None
:
return
None
# 包含所有
if
isinstance
(
input_key
,
list
):
return
input_key
if
isinstance
(
input_key
,
str
):
return
[
f
.
strip
()
for
f
in
input_key
.
split
(
','
)
if
f
.
strip
()]
if
input_key
.
strip
()
else
[]
return
None
    def _extract_qa(self, row, fields: Optional[List[str]] = None) -> List[dict]:
        """Extract Alpaca-format QA records from a single dataframe row.

        Args:
            row: mapping-like row exposing .get(); row[self.qa_key] may be a
                dict with a 'qa_pairs' list or a bare list of QA dicts.
            fields: ordered field names to assemble into the 'input' context;
                None selects the default (question + reasoning_steps +
                supporting_facts).

        Returns:
            List of {'instruction', 'input', 'output'} dicts; pairs missing a
            non-empty question or answer are skipped.
        """
        qa_data = row.get(self.qa_key)
        if not qa_data:
            return []
        # Support both the nested {'qa_pairs': [...]} shape and a plain list.
        qa_list = qa_data.get('qa_pairs', []) if isinstance(qa_data, dict) else qa_data
        if not isinstance(qa_list, list):
            return []
        results = []
        default_fields = ['question', 'reasoning_steps', 'supporting_facts']
        fields = fields if fields is not None else default_fields
        for qa in qa_list:
            if not isinstance(qa, dict):
                continue
            question = qa.get('question', '').strip()
            answer = qa.get('answer', '').strip()
            if not question or not answer:
                continue
            # Assemble the 'input' context in the order given by *fields*;
            # a blank line is inserted between sections.
            parts = []
            for field in fields:
                if field == 'question':
                    parts.append(f"Question: {question}")
                elif field == 'reasoning_steps' and qa.get('reasoning_steps'):
                    if parts:
                        parts.append("")
                    parts.append("Reasoning Process:")
                    # Steps may be dicts with a 'step' key or plain strings.
                    for i, step in enumerate(qa['reasoning_steps'], 1):
                        text = step.get('step', step) if isinstance(step, dict) else str(step)
                        if text:
                            parts.append(f"{i}. {text}")
                elif field == 'supporting_facts' and qa.get('supporting_facts'):
                    if parts:
                        parts.append("")
                    parts.append("Supporting Information:")
                    # Facts may be dicts with a 'fact' key or plain strings.
                    for fact in qa['supporting_facts']:
                        text = fact.get('fact', fact) if isinstance(fact, dict) else str(fact)
                        if text:
                            parts.append(f"- {text}")
                elif field in qa and qa[field]:
                    # Any other requested field is emitted as "name: value".
                    if parts:
                        parts.append("")
                    parts.append(f"{field}: {qa[field]}")
            results.append({
                'instruction': self.instruction,
                'input': "\n".join(parts),
                'output': answer
            })
        return results
    def _load_from_files(self, df):
        """Load QA data from per-row chunk files when the dataframe lacks the QA column.

        The first of 'enhanced_chunk_path' / 'cleaned_chunk_path' /
        'chunk_path' present in *df* is used as the file-path column; each
        referenced JSON file is read and every chunk containing self.qa_key
        becomes one row of the returned DataFrame.

        Returns:
            pandas.DataFrame with columns [self.qa_key, 'source_file'].

        Raises:
            ValueError: when no path column exists or no valid QA data is found.
        """
        import pandas as pd
        path_keys = ['enhanced_chunk_path', 'cleaned_chunk_path', 'chunk_path']
        path_col = next((k for k in path_keys if k in df.columns), None)
        if not path_col:
            raise ValueError(f"需要这些字段之一: {path_keys}")
        rows = []
        for _, row in df.iterrows():
            file_path = row[path_col]
            # Skip rows with a missing path or a path that no longer exists.
            if not file_path or not Path(file_path).exists():
                continue
            try:
                with open(file_path, 'r', encoding='utf-8') as f:
                    chunks = json.load(f)
                # A file may hold one chunk dict or a list of chunks.
                chunks = chunks if isinstance(chunks, list) else [chunks]
                for chunk in chunks:
                    if self.qa_key in chunk:
                        rows.append({
                            self.qa_key: chunk[self.qa_key],
                            'source_file': file_path
                        })
            except Exception as e:
                # Best-effort: log and continue with the remaining files.
                self.logger.error(f"加载失败 {file_path}: {e}")
        if not rows:
            raise ValueError("未找到有效QA数据")
        return pd.DataFrame(rows)
    def run(self,
            storage: DataFlowStorage,
            input_key: Optional[str] = None,
            output_key: Optional[str] = None) -> List[str]:
        """Extract QA pairs from storage and write Alpaca-format records back.

        Args:
            storage: DataFlow storage to read the dataframe from / write to.
            input_key: field selector forwarded to _parse_fields() (None =
                all default fields; '' = no extra context; comma string or
                list = explicit fields).
            output_key: accepted for interface conformity; not used.

        Returns:
            ['instruction', 'input', 'output'] — the output column names.
        """
        import pandas as pd
        self.logger.info("开始提取QA对...")
        df = storage.read(output_type="dataframe")
        # Fall back to loading QA data from chunk files when the column is absent.
        if self.qa_key not in df.columns:
            df = self._load_from_files(df)
        # Extract all QA pairs row by row.
        fields = self._parse_fields(input_key)
        all_qas = []
        for _, row in df.iterrows():
            all_qas.extend(self._extract_qa(row, fields))
        self.logger.info(f"提取了 {len(all_qas)} 个QA对")
        if not all_qas:
            self.logger.warning("未提取到QA对!")
            return ['instruction', 'input', 'output']
        # Optionally persist the records as a standalone JSON file.
        if self.output_json_file:
            output_path = Path(self.output_json_file)
            output_path.parent.mkdir(parents=True, exist_ok=True)
            with open(output_path, 'w', encoding='utf-8') as f:
                json.dump(all_qas, f, indent=2, ensure_ascii=False)
            self.logger.info(f"已保存到 {output_path}")
        # Write the extracted records back to storage.
        storage.write(pd.DataFrame(all_qas))
        return ['instruction', 'input', 'output']
\ No newline at end of file
dataflow/operators/pdf2vqa/__init__.py
0 → 100644
View file @
97e8278b
from typing import TYPE_CHECKING

# Static type checkers see the real import; at runtime the package module is
# replaced by a LazyLoader so operator modules are only imported on first use.
if TYPE_CHECKING:
    from .generate.vqa_extractor import VQAExtractor
else:
    import sys
    from dataflow.utils.registry import LazyLoader, generate_import_structure_from_type_checking

    cur_path = "dataflow/operators/pdf2vqa/"

    # Derive the lazy import map from the TYPE_CHECKING imports above.
    _import_structure = generate_import_structure_from_type_checking(__file__, cur_path)
    sys.modules[__name__] = LazyLoader(__name__, "dataflow/operators/pdf2vqa/", _import_structure)
dataflow/operators/pdf2vqa/generate/vqa_extractor.py
0 → 100644
View file @
97e8278b
import
os
import
json
import
re
import
pandas
as
pd
import
tiktoken
import
shutil
import
torch
from
pathlib
import
Path
from
typing
import
Literal
from
dataflow.core
import
OperatorABC
from
dataflow.utils.registry
import
OPERATOR_REGISTRY
from
dataflow.utils.storage
import
DataFlowStorage
from
dataflow
import
get_logger
from
dataflow.core
import
LLMServingABC
from
dataflow.prompts.pdf2vqa
import
QAExtractPrompt
from
dataflow.core.prompt
import
prompt_restrict
from
dataflow.utils.pdf2vqa.format_utils
import
merge_qa_pair
,
jsonl_to_md
@
prompt_restrict
(
QAExtractPrompt
)
@
OPERATOR_REGISTRY
.
register
()
class
VQAExtractor
(
OperatorABC
):
    def __init__(self,
                 llm_serving: LLMServingABC = None,
                 mineru_backend: Literal["vlm-transformers", "vlm-vllm-engine"] = "vlm-transformers",
                 max_chunk_len: int = 128000,):
        """Configure the VQA extractor.

        Args:
            llm_serving: LLM serving backend used for QA generation.
            mineru_backend: MinerU backend for layout extraction
                ('pipeline' is rejected later in _extract_doc_layout).
            max_chunk_len: token budget per request (system prompt included)
                when splitting a document into chunks.
        """
        self.logger = get_logger()
        self.llm_serving = llm_serving
        self.prompt_template = QAExtractPrompt()
        self.mineru_backend = mineru_backend
        self.max_chunk_len = max_chunk_len
    def _convert_json(self, input_file, output_file):
        """Flatten a MinerU content list into sequentially-id'd items.

        Reads *input_file* (a JSON list), strips 'bbox'/'page_idx', assigns a
        running integer 'id' to every item, and expands text-type list items
        into individual text entries. The result is written to *output_file*.

        NOTE(review): items with type 'list' whose sub_type is not 'text' are
        silently dropped (no branch appends them and the id counter is not
        advanced) — confirm this is intentional.
        """
        with open(input_file, 'r') as infile:
            data = list(json.load(infile))
        new_data = []
        # 'id' shadows the builtin; kept as-is to preserve the original code.
        id = 0
        for item in data:
            item['id'] = id
            # Layout coordinates and page numbers are not needed downstream.
            item.pop('bbox', None)
            item.pop('page_idx', None)
            if item.get('type', '') == 'list':
                if item['sub_type'] == 'text':
                    # Expand each list entry into its own text item with a
                    # consecutive id.
                    for idx, list_item in enumerate(item.get('list_items', [])):
                        new_item = {
                            'type': 'text',
                            'text': list_item,
                            'id': id + idx,
                        }
                        new_data.append(new_item)
                    id += len(item.get('list_items', []))
            else:
                new_data.append(item)
                id += 1
        with open(output_file, 'w') as outfile:
            json.dump(new_data, outfile, ensure_ascii=False)
def
_count_tokens
(
self
,
text
:
str
)
->
int
:
enc
=
tiktoken
.
get_encoding
(
"cl100k_base"
)
return
len
(
enc
.
encode
(
text
))
    def _id_to_text(self, input_ids, input_json, image_prefix="images"):
        """Resolve a comma-separated id string into the referenced item texts.

        Args:
            input_ids: string like "3, 5,7" of indices into *input_json*
                (non-integers and out-of-range ids are silently skipped).
            input_json: flattened content list from _convert_json().
            image_prefix: folder prefix used when rewriting image paths.

        Returns:
            The resolved texts joined with newlines.
        """
        texts = []
        id_list = input_ids.replace(' ', '').split(',')
        for id in id_list:
            # Skip tokens that are not integers (best-effort parsing of model output).
            try:
                int(id)
            except:
                continue
            if int(id) < len(input_json):
                try:
                    item = input_json[int(id)]
                except:
                    continue
                if 'text' in item:
                    texts.append(item['text'])
                elif 'img_path' in item:
                    try:
                        img_path = item.get('img_path', '')
                        img_name = os.path.basename(img_path)
                        new_path = f"{image_prefix}/{img_name}"
                        # NOTE(review): an empty f-string is appended here even
                        # though new_path was just computed — this looks like a
                        # lost "![]({new_path})" Markdown image link; confirm
                        # against the original repository.
                        texts.append(f"")
                    except:
                        pass
                elif item.get('type', '') == 'list':
                    if item['sub_type'] == 'text':
                        # Consume list items one at a time across repeated
                        # references to the same id (mutates input_json).
                        try:
                            texts.append(input_json[int(id)]['list_items'].pop(0))
                        except:
                            pass
        return '\n'.join(texts)
    def _extract_doc_layout(self,
                            input_pdf_file_path: str,
                            output_folder: str,
                            mineru_backend: Literal["vlm-transformers", "vlm-vllm-engine"] = "vlm-transformers"):
        """Extract PDF layout info via MinerU (merged from VQAExtractDocLayoutMinerU).

        Args:
            input_pdf_file_path: PDF to process.
            output_folder: directory MinerU writes into.
            mineru_backend: MinerU backend; 'pipeline' is rejected because its
                output format is incompatible.

        Returns:
            (content_list_json_path, layout_pdf_path).

        Raises:
            Exception: when mineru / pypdf / reportlab are not installed.
            ValueError: for the unsupported 'pipeline' backend.
            RuntimeError: when the MinerU CLI exits non-zero.
        """
        try:
            import mineru
            from mineru.cli.client import main as mineru_main
        except ImportError:
            raise Exception(
                """
MinerU is not installed in this environment yet.
Please refer to https://github.com/opendatalab/mineru to install.
Or you can just execute 'pip install mineru[pipeline]' and 'mineru-models-download' to fix this error.
Please make sure you have GPU on your machine.
"""
            )
        try:
            from pypdf import PdfReader, PdfWriter, PageObject
        except ImportError:
            raise Exception(
                """
pypdf is not installed in this environment yet.
Please use pip install pypdf.
"""
            )
        try:
            from reportlab.pdfgen import canvas
        except ImportError:
            raise Exception(
                """
reportlab is not installed in this environment yet.
Please use pip install reportlab.
"""
            )
        os.environ['MINERU_MODEL_SOURCE'] = "local"
        # Maps each backend to the sub-directory MinerU writes its output under.
        MinerU_Version = {
            "pipeline": "auto",
            "vlm-transformers": "vlm",
            "vlm-vllm-engine": "vlm"
        }
        if mineru_backend == "pipeline":
            raise ValueError("The 'pipeline' backend is not supported due to its incompatible output format. Please use 'vlm-transformers' or 'vlm-vllm-engine' instead.")
        raw_file = Path(input_pdf_file_path)
        pdf_name = raw_file.stem
        intermediate_dir = output_folder
        args = [
            "-p", str(raw_file),
            "-o", str(intermediate_dir),
            "-b", mineru_backend,
            "--source", "local"
        ]
        if mineru_backend == "vlm-vllm-engine":
            assert torch.cuda.is_available(), "MinerU vlm-vllm-engine backend requires GPU support."
            # Use tensor parallelism across two GPUs when available.
            args += ["--tensor-parallel-size", "2" if torch.cuda.device_count() >= 2 else "1"]
        try:
            mineru_main(args)
        except SystemExit as e:
            # The CLI exits via SystemExit; only a non-zero code is an error.
            if e.code != 0:
                raise RuntimeError(f"MinerU execution failed with exit code: {e.code}")
        output_json_file = os.path.join(intermediate_dir, pdf_name, MinerU_Version[mineru_backend],
                                        f"{pdf_name}_content_list.json")
        output_layout_file = os.path.join(intermediate_dir, pdf_name, MinerU_Version[mineru_backend],
                                          f"{pdf_name}_layout.pdf")
        return output_json_file, output_layout_file
    def _convert_response(self, input_response, input_json_path, image_prefix="images"):
        """Parse an LLM response with <chapter>/<qa_pair> tags into QA dicts.

        The response references content items by id; ids are resolved back to
        text (and image links) through _id_to_text() using the converted
        content-list JSON at *input_json_path*.

        Returns:
            List of dicts with keys 'question', 'answer', 'solution',
            'label', 'chapter_title'. A qa_pair is kept only when it has a
            <label> plus at least one of question/answer/solution.
        """
        qa_list = []
        with open(input_json_path, 'r') as infile:
            input_json = list(json.load(infile))
        # Extract each chapter block and its title.
        for chapter_block in re.findall(r'<chapter>(.*?)</chapter>', input_response, flags=re.DOTALL):
            title = re.search(r'<title>(.*?)</title>', chapter_block, flags=re.DOTALL)
            if title:
                chapter_title = self._id_to_text(title.group(1).strip(), input_json, image_prefix)
            else:
                chapter_title = ""
            # Find every qa_pair block inside the chapter.
            for pair in re.findall(r'<qa_pair>(.*?)</qa_pair>', chapter_block, flags=re.DOTALL):
                # Extract the question part.
                q_match = re.search(r'<question>(.*?)</question>', pair, flags=re.DOTALL)
                # Extract the answer part.
                a_match = re.search(r'<answer>(.*?)</answer>', pair, flags=re.DOTALL)
                # Extract the solution part.
                s_match = re.search(r'<solution>(.*?)</solution>', pair, flags=re.DOTALL)
                # Extract the label.
                label_match = re.search(r'<label>(.*?)</label>', pair, flags=re.DOTALL)
                # Require a label together with at least one content section.
                if not ((q_match and label_match) or (a_match and label_match) or (s_match and label_match)):
                    continue
                label = label_match.group(1).strip()
                qa_list.append({
                    # question/solution hold content-item ids -> resolve them;
                    # answer text is taken verbatim from the response.
                    'question': self._id_to_text(q_match.group(1).strip(), input_json, image_prefix) if q_match else "",
                    'answer': a_match.group(1).strip() if a_match else "",
                    'solution': self._id_to_text(s_match.group(1).strip(), input_json, image_prefix) if s_match else "",
                    'label': label,
                    'chapter_title': chapter_title
                })
        return qa_list
def
run
(
self
,
storage
:
DataFlowStorage
,
input_question_pdf_path_key
:
str
=
"question_pdf_path"
,
input_answer_pdf_path_key
:
str
=
"answer_pdf_path"
,
input_pdf_path_key
:
str
=
"pdf_path"
,
# 支持 interleaved 模式的单一 pdf_path
input_subject_key
:
str
=
"subject"
,
output_dir_key
:
str
=
"output_dir"
,
output_jsonl_key
:
str
=
"output_jsonl_path"
,
output_default_dir
:
str
=
"../vqa_output"
)
->
list
:
dataframe
=
storage
.
read
(
"dataframe"
)
# 支持两种输入格式:question_pdf_path/answer_pdf_path 或 pdf_path
if
input_question_pdf_path_key
not
in
dataframe
.
columns
and
input_pdf_path_key
not
in
dataframe
.
columns
:
raise
ValueError
(
f
"Column '
{
input_question_pdf_path_key
}
' or '
{
input_pdf_path_key
}
' not found in dataframe"
)
# ========== Stage 1: 预处理(任务扩展 + Layout 提取) ==========
expanded_rows
=
[]
for
idx
,
row
in
dataframe
.
iterrows
():
# 优先使用 question_pdf_path,如果没有则使用 pdf_path(interleaved 模式)
if
input_question_pdf_path_key
in
dataframe
.
columns
:
question_pdf_path
=
row
[
input_question_pdf_path_key
]
answer_pdf_path
=
row
.
get
(
input_answer_pdf_path_key
,
question_pdf_path
)
else
:
# interleaved 模式:使用同一个 pdf_path
question_pdf_path
=
row
[
input_pdf_path_key
]
answer_pdf_path
=
question_pdf_path
subject
=
row
.
get
(
input_subject_key
,
"math"
)
output_root
=
row
.
get
(
output_dir_key
,
output_default_dir
)
interleaved
=
(
question_pdf_path
==
answer_pdf_path
)
os
.
makedirs
(
output_root
,
exist_ok
=
True
)
# Question task
q_outdir
=
os
.
path
.
join
(
output_root
,
"question"
)
os
.
makedirs
(
q_outdir
,
exist_ok
=
True
)
# Layout 提取
q_json_path
,
_
=
self
.
_extract_doc_layout
(
input_pdf_file_path
=
question_pdf_path
,
output_folder
=
q_outdir
,
mineru_backend
=
self
.
mineru_backend
)
expanded_rows
.
append
({
"pdf_path"
:
question_pdf_path
,
"mode"
:
"question"
,
"interleaved"
:
interleaved
,
"subject"
:
subject
,
"output_dir"
:
q_outdir
,
"output_root"
:
output_root
,
"json_path"
:
q_json_path
})
# Answer task (if not interleaved)
if
not
interleaved
:
a_outdir
=
os
.
path
.
join
(
output_root
,
"answer"
)
os
.
makedirs
(
a_outdir
,
exist_ok
=
True
)
# Layout 提取
a_json_path
,
_
=
self
.
_extract_doc_layout
(
input_pdf_file_path
=
answer_pdf_path
,
output_folder
=
a_outdir
,
mineru_backend
=
self
.
mineru_backend
)
expanded_rows
.
append
({
"pdf_path"
:
answer_pdf_path
,
"mode"
:
"answer"
,
"interleaved"
:
interleaved
,
"subject"
:
subject
,
"output_dir"
:
a_outdir
,
"output_root"
:
output_root
,
"json_path"
:
a_json_path
})
# ========== Stage 2: QA 提取 ==========
json_paths
=
[
row
[
"json_path"
]
for
row
in
expanded_rows
]
subjects
=
[
row
[
"subject"
]
for
row
in
expanded_rows
]
user_inputs
=
[]
split_metadata
=
[]
for
idx
,
input_json_path
in
enumerate
(
json_paths
):
subject
=
subjects
[
idx
]
if
idx
<
len
(
subjects
)
else
subjects
[
0
]
if
subjects
else
"math"
system_prompt
=
self
.
prompt_template
.
build_prompt
(
subject
)
system_prompt_len
=
self
.
_count_tokens
(
system_prompt
)
converted_path
=
input_json_path
.
replace
(
'.json'
,
'_converted.json'
)
self
.
_convert_json
(
input_json_path
,
converted_path
)
with
open
(
converted_path
,
'r'
)
as
infile
:
data
=
json
.
load
(
infile
)
assert
isinstance
(
data
,
list
),
f
"Expected list, got
{
type
(
data
)
}
for
{
input_json_path
}
"
# 分段处理
current_chunk
,
current_len
=
[],
system_prompt_len
chunks
=
[]
for
item
in
data
:
text
=
json
.
dumps
(
item
,
ensure_ascii
=
False
)
item_len
=
self
.
_count_tokens
(
text
)
if
current_len
+
item_len
>
self
.
max_chunk_len
and
current_chunk
:
chunks
.
append
(
current_chunk
)
current_chunk
,
current_len
=
[],
system_prompt_len
current_chunk
.
append
(
item
)
current_len
+=
item_len
if
current_chunk
:
chunks
.
append
(
current_chunk
)
split_metadata
.
append
(
len
(
chunks
))
for
chunk
in
chunks
:
user_inputs
.
append
({
'user_input'
:
json
.
dumps
(
chunk
,
ensure_ascii
=
False
),
'system_prompt'
:
system_prompt
})
# 批量生成
responses
=
[
None
]
*
len
(
user_inputs
)
current_batch
=
[]
current_batch_indices
=
[]
current_system_prompt
=
None
for
idx
,
item
in
enumerate
(
user_inputs
):
user_input
=
item
[
'user_input'
]
system_prompt
=
item
[
'system_prompt'
]
if
current_system_prompt
is
None
:
current_system_prompt
=
system_prompt
current_batch
=
[
user_input
]
current_batch_indices
=
[
idx
]
elif
system_prompt
==
current_system_prompt
:
current_batch
.
append
(
user_input
)
current_batch_indices
.
append
(
idx
)
else
:
# 处理当前批次
batch_responses
=
self
.
llm_serving
.
generate_from_input
(
user_inputs
=
current_batch
,
system_prompt
=
current_system_prompt
)
for
batch_idx
,
resp
in
zip
(
current_batch_indices
,
batch_responses
):
responses
[
batch_idx
]
=
resp
# 开始新批次
current_system_prompt
=
system_prompt
current_batch
=
[
user_input
]
current_batch_indices
=
[
idx
]
# 处理最后一批
if
current_batch
:
batch_responses
=
self
.
llm_serving
.
generate_from_input
(
user_inputs
=
current_batch
,
system_prompt
=
current_system_prompt
)
for
batch_idx
,
resp
in
zip
(
current_batch_indices
,
batch_responses
):
responses
[
batch_idx
]
=
resp
# 按 split_metadata 还原
recombined_responses
=
[]
idx
=
0
for
num_chunks
in
split_metadata
:
merged_text
=
"
\n
"
.
join
(
responses
[
idx
:
idx
+
num_chunks
])
recombined_responses
.
append
(
merged_text
)
idx
+=
num_chunks
# ========== Stage 3: 后处理(Response 转换 + 合并和过滤) ==========
# Response 转换
qa_lists
=
[]
for
idx
,
(
response
,
row
)
in
enumerate
(
zip
(
recombined_responses
,
expanded_rows
)):
json_path
=
row
[
"json_path"
]
output_dir
=
row
[
"output_dir"
]
mode
=
row
[
"mode"
]
output_root
=
row
[
"output_root"
]
image_prefix
=
f
"
{
mode
}
_images"
converted_json_path
=
json_path
.
replace
(
'.json'
,
'_converted.json'
)
qa_list
=
self
.
_convert_response
(
response
,
converted_json_path
,
image_prefix
)
# 复制图片
src_dir
=
os
.
path
.
join
(
output_dir
,
Path
(
json_path
).
stem
).
replace
(
'_content_list'
,
''
)
src_images
=
os
.
path
.
join
(
src_dir
,
'vlm'
,
'images'
)
dst_images
=
os
.
path
.
join
(
output_root
,
image_prefix
)
try
:
if
os
.
path
.
exists
(
src_images
):
if
os
.
path
.
exists
(
dst_images
):
shutil
.
rmtree
(
dst_images
)
shutil
.
copytree
(
src_images
,
dst_images
)
else
:
self
.
logger
.
warning
(
f
"Source images dir does not exist:
{
src_images
}
"
)
except
Exception
as
e
:
self
.
logger
.
warning
(
f
"Failed to copy images from
{
src_images
}
to
{
dst_images
}
:
{
e
}
"
)
qa_lists
.
append
(
qa_list
)
# 按 output_root 分组处理合并和过滤
output_groups
=
{}
for
idx
,
(
qa_list
,
row
)
in
enumerate
(
zip
(
qa_lists
,
expanded_rows
)):
output_root
=
row
[
"output_root"
]
mode
=
row
[
"mode"
]
interleaved
=
row
[
"interleaved"
]
output_dir
=
row
[
"output_dir"
]
if
output_root
not
in
output_groups
:
output_groups
[
output_root
]
=
{
"question"
:
None
,
"answer"
:
None
,
"interleaved"
:
interleaved
}
if
mode
==
"question"
:
output_groups
[
output_root
][
"question"
]
=
(
qa_list
,
output_dir
)
elif
mode
==
"answer"
:
output_groups
[
output_root
][
"answer"
]
=
(
qa_list
,
output_dir
)
# 处理每个 output_root
result_paths_dict
=
{}
for
output_root
,
group_info
in
output_groups
.
items
():
q_qa_list
,
q_output_dir
=
group_info
[
"question"
]
if
group_info
[
"question"
]
else
(
None
,
None
)
a_qa_list
,
a_output_dir
=
group_info
[
"answer"
]
if
group_info
[
"answer"
]
else
(
None
,
None
)
interleaved
=
group_info
[
"interleaved"
]
# 写入 question jsonl
q_jsonl_path
=
os
.
path
.
join
(
output_root
,
"vqa_extracted_questions.jsonl"
)
if
q_qa_list
:
with
open
(
q_jsonl_path
,
'w'
,
encoding
=
'utf-8'
)
as
f
:
for
item
in
q_qa_list
:
f
.
write
(
json
.
dumps
(
item
,
ensure_ascii
=
False
)
+
'
\n
'
)
# 写入 answer jsonl(如果不是 interleaved)
a_jsonl_path
=
None
if
not
interleaved
and
a_qa_list
:
a_jsonl_path
=
os
.
path
.
join
(
output_root
,
"vqa_extracted_answers.jsonl"
)
with
open
(
a_jsonl_path
,
'w'
,
encoding
=
'utf-8'
)
as
f
:
for
item
in
a_qa_list
:
f
.
write
(
json
.
dumps
(
item
,
ensure_ascii
=
False
)
+
'
\n
'
)
# 合并
merged_jsonl
=
os
.
path
.
join
(
output_root
,
"vqa_merged_qa_pairs.jsonl"
)
if
not
interleaved
and
a_jsonl_path
:
merge_qa_pair
(
q_jsonl_path
,
a_jsonl_path
,
merged_jsonl
)
else
:
os
.
system
(
f
"cp
{
q_jsonl_path
}
{
merged_jsonl
}
"
)
# 过滤
filtered_items
=
[]
total_count
=
0
with
open
(
merged_jsonl
,
'r'
,
encoding
=
'utf-8'
)
as
f
:
for
line
in
f
:
total_count
+=
1
item
=
json
.
loads
(
line
)
if
item
.
get
(
'question'
,
''
).
strip
()
and
(
item
.
get
(
'answer'
,
''
).
strip
()
or
item
.
get
(
'solution'
,
''
).
strip
()):
filtered_items
.
append
(
item
)
self
.
logger
.
info
(
f
"Before filter:
{
total_count
}
, After filter:
{
len
(
filtered_items
)
}
"
)
filtered_jsonl
=
os
.
path
.
join
(
output_root
,
"vqa_filtered_qa_pairs.jsonl"
)
with
open
(
filtered_jsonl
,
'w'
,
encoding
=
'utf-8'
)
as
f
:
for
item
in
filtered_items
:
f
.
write
(
json
.
dumps
(
item
,
ensure_ascii
=
False
)
+
'
\n
'
)
# 转换为 markdown
md_output
=
os
.
path
.
join
(
output_root
,
"vqa_filtered_qa_pairs.md"
)
jsonl_to_md
(
filtered_jsonl
,
md_output
)
result_paths_dict
[
output_root
]
=
filtered_jsonl
# 为原始 dataframe 的每一行分配结果路径
result_paths
=
[]
for
idx
,
row
in
dataframe
.
iterrows
():
if
input_question_pdf_path_key
in
dataframe
.
columns
:
question_pdf_path
=
row
[
input_question_pdf_path_key
]
answer_pdf_path
=
row
.
get
(
input_answer_pdf_path_key
,
question_pdf_path
)
else
:
question_pdf_path
=
row
[
input_pdf_path_key
]
answer_pdf_path
=
question_pdf_path
output_root
=
row
.
get
(
output_dir_key
,
output_default_dir
)
result_paths
.
append
(
result_paths_dict
.
get
(
output_root
))
dataframe
[
output_jsonl_key
]
=
result_paths
output_file
=
storage
.
write
(
dataframe
)
self
.
logger
.
info
(
f
"VQA extraction complete. Results saved to
{
output_file
}
"
)
return
[
output_jsonl_key
,]
dataflow/operators/reasoning/__init__.py
0 → 100644
View file @
97e8278b
"""Package init for ``dataflow.operators.reasoning``.

At type-check time the operator classes are imported eagerly so IDEs and
static analyzers see the full API; at runtime the module replaces itself
with a ``LazyLoader`` so operators are only imported on first attribute
access.
"""
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # generate
    from .generate.reasoning_answer_generator import ReasoningAnswerGenerator
    from .generate.reasoning_question_generator import ReasoningQuestionGenerator
    from .generate.reasoning_answer_extraction_qwenmatheval_generator import ReasoningAnswerExtractionQwenMathEvalGenerator
    from .generate.reasoning_pseudo_answer_generator import ReasoningPseudoAnswerGenerator
    from .generate.reasoning_pretrain_format_convert_generator import ReasoningPretrainFormatConvertGenerator
    from .generate.reasoning_question_fusion_generator import ReasoningQuestionFusionGenerator

    # eval
    from .eval.reasoning_category_dataset_evaluator import ReasoningCategoryDatasetEvaluator
    from .eval.reasoning_difficulty_dataset_evaluator import ReasoningDifficultyDatasetEvaluator
    from .eval.reasoning_token_dataset_evaluator import ReasoningTokenDatasetEvaluator
    from .eval.reasoning_question_category_sample_evaluator import ReasoningQuestionCategorySampleEvaluator
    from .eval.reasoning_question_difficulty_sample_evaluator import ReasoningQuestionDifficultySampleEvaluator
    from .eval.reasoning_question_solvable_sample_evaluator import ReasoningQuestionSolvableSampleEvaluator

    # filter
    from .filter.reasoning_answer_formatter_filter import ReasoningAnswerFormatterFilter
    from .filter.reasoning_answer_groundtruth_filter import ReasoningAnswerGroundTruthFilter
    from .filter.reasoning_answer_ngram_filter import ReasoningAnswerNgramFilter
    from .filter.reasoning_answer_pipeline_root_filter import ReasoningAnswerPipelineRootFilter
    from .filter.reasoning_answer_token_length_filter import ReasoningAnswerTokenLengthFilter
    from .filter.reasoning_question_filter import ReasoningQuestionFilter
    from .filter.reasoning_answer_model_judge_filter import ReasoningAnswerModelJudgeFilter
else:
    import sys
    from dataflow.utils.registry import LazyLoader, generate_import_structure_from_type_checking

    # Derive the lazy import map by scanning the TYPE_CHECKING imports above.
    cur_path = "dataflow/operators/reasoning/"

    _import_structure = generate_import_structure_from_type_checking(__file__, cur_path)
    # Replace this module object with the lazy proxy.
    sys.modules[__name__] = LazyLoader(__name__, "dataflow/operators/reasoning/", _import_structure)
dataflow/operators/reasoning/eval/reasoning_category_dataset_evaluator.py
0 → 100644
View file @
97e8278b
from
dataflow.utils.registry
import
OPERATOR_REGISTRY
from
dataflow
import
get_logger
from
dataflow.core
import
OperatorABC
from
dataflow.utils.storage
import
DataFlowStorage
from
dataflow.utils.reasoning.CategoryFuzz
import
CategoryUtils
import
pandas
as
pd
@OPERATOR_REGISTRY.register()
class ReasoningCategoryDatasetEvaluator(OperatorABC):
    """Dataset-level evaluator that reports the distribution of primary and
    secondary categories over all samples read from storage."""

    def __init__(self):
        self.logger = get_logger()
        self.logger.info(f'Initializing {self.__class__.__name__} ...')
        self.logger.info(f'{self.__class__.__name__} initialized.')
        self.information_name = "Category Information"
        # Mapping of primary category -> collection of its secondary categories.
        self.category_list = CategoryUtils().secondary_categories

    @staticmethod
    def get_desc(lang: str = "zh"):
        """Return a human-readable description of the operator in *lang*."""
        if lang == "zh":
            return (
                "该算子用于统计数据集中的类别信息,包括主类别和次类别的分布情况。"
                "它计算每个类别的样本数量,并返回类别分布的统计结果。\n"
                "输入参数:\n"
                "- input_primary_category_key:主类别字段名,默认为'primary_category'\n"
                "- input_secondary_category_key:次类别字段名,默认为'secondary_category'\n"
                "输出参数:\n"
                "- 返回包含类别统计信息的字典,主类别作为键,值为包含该类别样本数量和次类别分布的字典"
            )
        elif lang == "en":
            return (
                "This operator analyzes category distribution in the dataset, including primary and secondary categories. "
                "It counts the number of samples in each category and returns statistical results of category distribution.\n"
                "Input Parameters:\n"
                "- input_primary_category_key: Field name for primary category, default is 'primary_category'\n"
                "- input_secondary_category_key: Field name for secondary category, default is 'secondary_category'\n\n"
                "Output Parameters:\n"
                "- Returns a dictionary containing category statistics, with primary categories as keys and values as dictionaries "
                "containing sample counts and secondary category distribution"
            )
        else:
            return (
                "CategoryInfo analyzes and reports the distribution of primary and secondary categories in the dataset."
            )

    def get_category_info(self, samples, input_primary_category_key="primary_category", input_secondary_category_key="secondary_category"):
        """Aggregate category counts over *samples*.

        Args:
            samples: list of per-row dicts (``dataframe.to_dict(orient='records')``).
            input_primary_category_key: dict key holding the primary category.
            input_secondary_category_key: dict key holding the secondary category.

        Returns:
            dict mapping each observed primary category to a dict containing
            its total sample count under ``"primary_num"`` plus one entry per
            observed secondary category.
        """
        primary_categories = [sample.get(input_primary_category_key, '') for sample in samples]
        secondary_categories = [sample.get(input_secondary_category_key, '') for sample in samples]
        primary_categories_count = pd.Series(primary_categories).value_counts().to_dict()
        secondary_categories_count = pd.Series(secondary_categories).value_counts().to_dict()
        # BUG FIX: this was `output = []`; assigning `output[primary] = js` on a
        # list raises TypeError (list indices must be integers). The function is
        # documented to return a dict keyed by primary category.
        output = {}
        for primary in self.category_list:
            if primary not in primary_categories_count:
                continue
            js = {}
            js["primary_num"] = primary_categories_count[primary]
            for secondary in self.category_list[primary]:
                if secondary not in secondary_categories_count:
                    continue
                js[secondary] = secondary_categories_count[secondary]
            output[primary] = js
        self.logger.info(f"Category information: {output}")
        return output

    def run(self, storage: DataFlowStorage, input_primary_category_key: str = "primary_category", input_secondary_category_key: str = "secondary_category"):
        """Read the dataframe from *storage* and return its category statistics.

        Returns an empty dict (after logging an error) when either required
        column is missing.
        """
        self.input_primary_category_key = input_primary_category_key
        self.input_secondary_category_key = input_secondary_category_key
        dataframe = storage.read("dataframe")
        if self.input_primary_category_key not in dataframe.columns or self.input_secondary_category_key not in dataframe.columns:
            self.logger.error(f"Input keys {self.input_primary_category_key} or {self.input_secondary_category_key} not found in dataframe columns.")
            return {}
        samples = dataframe.to_dict(orient='records')
        category_info = self.get_category_info(samples, self.input_primary_category_key, self.input_secondary_category_key)
        return category_info
\ No newline at end of file
dataflow/operators/reasoning/eval/reasoning_difficulty_dataset_evaluator.py
0 → 100644
View file @
97e8278b
from
dataflow.utils.registry
import
OPERATOR_REGISTRY
from
dataflow
import
get_logger
from
dataflow.core
import
OperatorABC
from
dataflow.utils.storage
import
DataFlowStorage
import
pandas
as
pd
@OPERATOR_REGISTRY.register()
class ReasoningDifficultyDatasetEvaluator(OperatorABC):
    """Dataset-level evaluator that reports how many samples fall into each
    difficulty score value."""

    def __init__(self):
        self.logger = get_logger()
        self.logger.info(f'Initializing {self.__class__.__name__} ...')
        self.logger.info(f'{self.__class__.__name__} initialized.')
        self.information_name = "Difficulty Information"

    @staticmethod
    def get_desc(lang: str = "zh"):
        """Return a human-readable description of the operator in *lang*."""
        if lang == "zh":
            return (
                "该算子用于统计数据集中的难度信息,计算不同难度级别的样本数量分布。"
                "它统计每个难度级别的样本数量,并返回难度分布的统计结果。\n"
                "输入参数:\n"
                "- input_diffulty_key:难度分数字段名,默认为'difficulty_score'\n"
                "输出参数:\n"
                "- 返回包含难度统计信息的字典,难度级别作为键,值为该难度级别的样本数量"
            )
        elif lang == "en":
            return (
                "This operator analyzes difficulty distribution in the dataset, calculating the number of samples at different difficulty levels. "
                "It counts samples at each difficulty level and returns statistical results of difficulty distribution.\n"
                "Input Parameters:\n"
                "- input_diffulty_key: Field name for difficulty score, default is 'difficulty_score'\n\n"
                "Output Parameters:\n"
                "- Returns a dictionary containing difficulty statistics, with difficulty levels as keys and sample counts as values"
            )
        else:
            return (
                "DifficultyInfo analyzes and reports the distribution of difficulty levels in the dataset."
            )

    def get_category_info(self, samples, input_diffulty_key="difficulty_score"):
        """Count samples per difficulty value; missing keys count as 'null'."""
        observed_scores = [record.get(input_diffulty_key, 'null') for record in samples]
        score_counts = pd.Series(observed_scores).value_counts().to_dict()
        self.logger.info(f"Difficulty information: {score_counts}")
        return score_counts

    def run(self, storage: DataFlowStorage, input_diffulty_key: str = "difficulty_score"):
        """Read the dataframe from *storage* and return its difficulty
        distribution, or {} (after logging) if the column is absent."""
        self.input_diffulty_key = input_diffulty_key
        dataframe = storage.read("dataframe")
        # Guard clause: nothing to count without the difficulty column.
        if self.input_diffulty_key not in dataframe.columns:
            self.logger.error(f"Input key {self.input_diffulty_key} not found in dataframe columns.")
            return {}
        records = dataframe.to_dict(orient='records')
        return self.get_category_info(records, self.input_diffulty_key)
\ No newline at end of file
dataflow/operators/reasoning/eval/reasoning_question_category_sample_evaluator.py
0 → 100644
View file @
97e8278b
from
dataflow.utils.storage
import
DataFlowStorage
from
dataflow.core
import
OperatorABC
from
dataflow.core.prompt
import
prompt_restrict
from
dataflow.utils.reasoning.CategoryFuzz
import
CategoryUtils
from
dataflow.core
import
LLMServingABC
from
dataflow.prompts.reasoning.math
import
MathQuestionCategoryPrompt
import
pandas
as
pd
import
json
import
re
from
dataflow.utils.registry
import
OPERATOR_REGISTRY
from
dataflow
import
get_logger
@prompt_restrict(MathQuestionCategoryPrompt)
@OPERATOR_REGISTRY.register()
class ReasoningQuestionCategorySampleEvaluator(OperatorABC):
    """Sample-level evaluator that asks an LLM to assign each question a
    primary and secondary category, then writes the normalized labels into
    the 'primary_category' / 'secondary_category' dataframe columns."""

    def __init__(self, llm_serving: LLMServingABC = None):
        """
        Initialize the ReasoningCategoryDatasetEvaluator with the provided configuration.
        """
        self.logger = get_logger()
        self.prompts = MathQuestionCategoryPrompt()
        self.llm_serving = llm_serving

    @staticmethod
    def get_desc(lang: str = "zh"):
        # Human-readable operator description; note there is no fallback
        # branch, so unknown languages return None.
        if lang == "zh":
            return (
                "该算子用于对用户问题进行多级分类(主分类和子分类)。"
                "通过大语言模型对输入问题进行语义分析,输出分类编码结果。\n\n"
                "输入参数:\n"
                "- db_port/db_name/table_name:数据库连接参数(存储模式)\n"
                "- input_file/output_file:文件路径(文件模式)\n"
                "- input_key:输入数据中问题字段的键名\n"
                "- generator_type:模型调用方式(aisuite/request)\n\n"
                "输出参数:\n"
                "- classification_result:包含主分类和子分类的编码结果"
            )
        elif lang == "en":
            return (
                "Performs hierarchical classification (primary and secondary) on user questions. "
                "Utilizes LLM for semantic analysis and outputs category codes.\n\n"
                "Input Parameters:\n"
                "- db_port/db_name/table_name: Database connection params (storage mode)\n"
                "- input_file/output_file: File paths (file mode)\n"
                "- input_key: Key for question field in input data\n"
                "- generator_type: Model invocation method (aisuite/request)\n\n"
                "Output Parameters:\n"
                "- classification_result: Combined category code"
            )

    def _validate_dataframe(self, dataframe: pd.DataFrame):
        """Raise ValueError if the input column is missing or the output
        column already exists (it would be overwritten)."""
        required_keys = [self.input_key]
        forbidden_keys = [self.output_key]
        missing = [k for k in required_keys if k not in dataframe.columns]
        conflict = [k for k in forbidden_keys if k in dataframe.columns]
        if missing:
            raise ValueError(f"Missing required column(s): {missing}")
        if conflict:
            raise ValueError(f"The following column(s) already exist and would be overwritten: {conflict}")

    def _reformat_prompt(self, dataframe):
        """
        Reformat the prompts in the dataframe to generate questions.
        """
        # Check if input_key is in the dataframe
        formatted_prompts = []
        for text in dataframe[self.input_key]:
            used_prompt = self.prompts.build_prompt(text)
            formatted_prompts.append(used_prompt.strip())
        return formatted_prompts

    def run(self, storage: DataFlowStorage, input_key: str = "instruction", output_key: str = "question_category") -> list:
        """
        Run the question category classification process.
        """
        self.input_key, self.output_key = input_key, output_key
        dataframe = storage.read("dataframe")
        self._validate_dataframe(dataframe)
        formatted_prompts = self._reformat_prompt(dataframe)
        responses = self.llm_serving.generate_from_input(formatted_prompts)

        # Pair each dataframe row with its LLM response; rows whose response
        # fails to parse keep no category values (errors are only logged).
        for (idx, row), classification_str in zip(dataframe.iterrows(), responses):
            try:
                if not classification_str:
                    raise ValueError("空字符串")
                # Strip Markdown ```json fencing around the payload.
                cleaned_str = re.sub(r"^```json\s*|\s*```$", "", classification_str.strip(), flags=re.DOTALL)
                # Drop non-ASCII characters (optional cleanup before JSON parse).
                cleaned_str = re.sub(r"[^\x00-\x7F]+", "", cleaned_str)
                classification = json.loads(cleaned_str)
                primary_raw = classification.get("primary_category", "")
                secondary_raw = classification.get("secondary_category", "")
                # Fuzzy-normalize raw labels onto the canonical category set.
                category_info = CategoryUtils().normalize_categories(raw_primary=primary_raw, raw_secondary=secondary_raw)
                dataframe.at[idx, "primary_category"] = category_info["primary_category"]
                dataframe.at[idx, "secondary_category"] = category_info["secondary_category"]
            except json.JSONDecodeError:
                self.logger.warning(f"[警告] JSON 解析失败,收到的原始数据:{repr(classification_str)}")
            except Exception as e:
                self.logger.error(f"[错误] 解析分类结果失败:{e}")
                self.logger.debug(f"[DEBUG] 原始字符串:{repr(classification_str)}")

        output_file = storage.write(dataframe)
        self.logger.info(f"Classification results saved to {output_file}")
        return ["primary_category", "secondary_category"]
\ No newline at end of file
dataflow/operators/reasoning/eval/reasoning_question_difficulty_sample_evaluator.py
0 → 100644
View file @
97e8278b
from
dataflow.prompts.reasoning.math
import
MathQuestionDifficultyPrompt
from
dataflow.utils.registry
import
OPERATOR_REGISTRY
from
dataflow
import
get_logger
from
dataflow.core.prompt
import
prompt_restrict
from
dataflow.utils.storage
import
DataFlowStorage
from
dataflow.core
import
OperatorABC
from
dataflow.core
import
LLMServingABC
import
pandas
as
pd
import
re
@prompt_restrict(MathQuestionDifficultyPrompt)
@OPERATOR_REGISTRY.register()
class ReasoningQuestionDifficultySampleEvaluator(OperatorABC):
    """Sample-level evaluator that asks an LLM to rate each question's
    difficulty and stores the parsed numeric score per row."""

    def __init__(self, llm_serving: LLMServingABC = None):
        """
        Initialize the evaluator with the LLM serving backend used to score questions.
        """
        self.logger = get_logger()
        self.prompts = MathQuestionDifficultyPrompt()
        self.llm_serving = llm_serving

    @staticmethod
    def get_desc(lang: str = "zh"):
        # Human-readable operator description; note there is no fallback
        # branch, so unknown languages return None.
        if lang == "zh":
            return (
                "该算子用于评估问题的难度等级。"
                "通过大语言模型分析问题复杂度,输出1-10级的难度评分。\n\n"
                "输入参数:\n"
                "- eval_stage:评估阶段标识\n"
                "- read_min/max_score:分数过滤阈值\n"
                "- 其他参数同ReasoningCategoryDatasetEvaluator\n\n"
                "输出参数:\n"
                "- difficulty_score:数值型难度评分(1-10)"
            )
        elif lang == "en":
            return (
                "Evaluates question difficulty level using LLM analysis. "
                "Outputs numerical difficulty score from 1 to 10.\n\n"
                "Input Parameters:\n"
                "- eval_stage: Evaluation stage identifier\n"
                "- read_min/max_score: Score filtering thresholds\n"
                "- Other params same as ReasoningCategoryDatasetEvaluator\n\n"
                "Output Parameters:\n"
                "- difficulty_score: Numerical difficulty rating (1-10)"
            )

    def _validate_dataframe(self, dataframe: pd.DataFrame, input_key: str = "instruction", output_key: str = "difficulty_score"):
        """Raise ValueError if the input column is missing or the output
        column already exists (it would be overwritten)."""
        required_keys = [self.input_key]
        forbidden_keys = [self.output_key]
        missing = [k for k in required_keys if k not in dataframe.columns]
        conflict = [k for k in forbidden_keys if k in dataframe.columns]
        if missing:
            raise ValueError(f"Missing required column(s): {missing}")
        if conflict:
            raise ValueError(f"The following column(s) already exist and would be overwritten: {conflict}")

    def _reformat_prompt(self, dataframe, input_key: str = "instruction") -> list:
        """
        Build one difficulty-evaluation prompt per row of *dataframe*.

        BUG FIX: the original set ``used_prompt = None`` for None questions
        and then unconditionally called ``used_prompt.strip()``, raising
        AttributeError. None questions now map to an empty prompt so the
        output list stays aligned one-to-one with the rows.
        """
        formatted_prompts = []
        for text in dataframe[input_key]:
            if text is not None:
                formatted_prompts.append(self.prompts.build_prompt(text).strip())
            else:
                formatted_prompts.append("")
        return formatted_prompts

    def run(self, storage: DataFlowStorage, input_key: str, output_key: str = "difficulty_score") -> list:
        """
        Run the question difficulty classification process.

        Scores are parsed from each response's "Rating: <number>" marker;
        rows whose response has no parseable rating get the sentinel -1.
        Returns the list of written output keys.
        """
        self.input_key, self.output_key = input_key, output_key
        dataframe = storage.read("dataframe")
        self._validate_dataframe(dataframe, input_key=self.input_key, output_key=self.output_key)
        formatted_prompts = self._reformat_prompt(dataframe, input_key=self.input_key)
        responses = self.llm_serving.generate_from_input(user_inputs=formatted_prompts)

        rating_scores = []
        for response in responses:
            # Accept both integer and decimal ratings, e.g. "Rating: 7" / "Rating: 7.5".
            match = re.search(r'Rating:\s*((\d+\.\d+)|\d+)', response)
            if match:
                score_str = match.group(1).rstrip('.')
                try:
                    score = float(score_str)
                except ValueError:
                    score = -1
            else:
                score = -1
            rating_scores.append(score)

        dataframe[output_key] = rating_scores
        output_file = storage.write(dataframe)
        self.logger.info(f"Classification results saved to {output_file}")
        return [output_key]
\ No newline at end of file
dataflow/operators/reasoning/eval/reasoning_question_solvable_sample_evaluator.py
0 → 100644
View file @
97e8278b
from
dataflow.utils.storage
import
DataFlowStorage
from
dataflow.core
import
OperatorABC
from
dataflow.core.prompt
import
prompt_restrict
,
DIYPromptABC
from
dataflow.core
import
LLMServingABC
from
dataflow.prompts.reasoning.math
import
MathQuestionEvaluatorPrompt
from
typing
import
Union
import
pandas
as
pd
import
json
import
re
from
dataflow.utils.registry
import
OPERATOR_REGISTRY
from
dataflow
import
get_logger
@prompt_restrict(MathQuestionEvaluatorPrompt)
@OPERATOR_REGISTRY.register()
class ReasoningQuestionSolvableSampleEvaluator(OperatorABC):
    """Sample-level evaluator that sends every question to an LLM together
    with a system prompt and stores the raw responses under *output_key*."""

    def __init__(self, llm_serving: LLMServingABC = None, prompt_template: Union[MathQuestionEvaluatorPrompt, DIYPromptABC] = None):
        """
        Initialize the evaluator; falls back to MathQuestionEvaluatorPrompt
        when no prompt template is supplied.
        """
        self.logger = get_logger()
        self.prompt_template = prompt_template if prompt_template is not None else MathQuestionEvaluatorPrompt()
        self.llm_serving = llm_serving

    @staticmethod
    def get_desc(lang: str = "zh"):
        # Human-readable operator description; unknown languages return None.
        if lang == "zh":
            return (
                "该算子用于对用户问题进行多级分类(主分类和子分类)。"
                "通过大语言模型对输入问题进行语义分析,输出分类编码结果。\n\n"
                "输入参数:\n"
                "- db_port/db_name/table_name:数据库连接参数(存储模式)\n"
                "- input_file/output_file:文件路径(文件模式)\n"
                "- input_key:输入数据中问题字段的键名\n"
                "- generator_type:模型调用方式(aisuite/request)\n\n"
                "输出参数:\n"
                "- classification_result:包含主分类和子分类的编码结果"
            )
        elif lang == "en":
            return (
                "Performs hierarchical classification (primary and secondary) on user questions. "
                "Utilizes LLM for semantic analysis and outputs category codes.\n\n"
                "Input Parameters:\n"
                "- db_port/db_name/table_name: Database connection params (storage mode)\n"
                "- input_file/output_file: File paths (file mode)\n"
                "- input_key: Key for question field in input data\n"
                "- generator_type: Model invocation method (aisuite/request)\n\n"
                "Output Parameters:\n"
                "- classification_result: Combined category code"
            )

    def _validate_dataframe(self, dataframe: pd.DataFrame):
        """Raise ValueError if the input column is missing or the output
        column already exists (it would be overwritten)."""
        needed = [self.input_key]
        reserved = [self.output_key]
        missing = [col for col in needed if col not in dataframe.columns]
        conflict = [col for col in reserved if col in dataframe.columns]
        if missing:
            raise ValueError(f"Missing required column(s): {missing}")
        if conflict:
            raise ValueError(f"The following column(s) already exist and would be overwritten: {conflict}")

    def _reformat_prompt(self, dataframe):
        """Return (system_prompt, per-row user prompts) for the LLM call."""
        questions = dataframe[self.input_key].tolist()
        system_prompt = self.prompt_template.build_system_prompt()
        user_prompts = [self.prompt_template.build_prompt(question) for question in questions]
        return system_prompt, user_prompts

    def run(self, storage: DataFlowStorage, input_key: str, output_key: str):
        """
        Run the question generation process: build prompts, query the LLM,
        attach responses as a new column, and persist the dataframe.
        """
        self.input_key, self.output_key = input_key, output_key
        dataframe = storage.read("dataframe")
        self._validate_dataframe(dataframe)
        sys_prompts, user_prompts = self._reformat_prompt(dataframe)
        responses = self.llm_serving.generate_from_input(user_prompts, sys_prompts)
        dataframe[output_key] = responses
        self.logger.info(f"Generated questions for {output_key}")
        output_file = storage.write(dataframe)
        self.logger.info(f"Generated questions saved to {output_file}")
\ No newline at end of file
dataflow/operators/reasoning/eval/reasoning_token_dataset_evaluator.py
0 → 100644
View file @
97e8278b
from
dataflow.utils.registry
import
OPERATOR_REGISTRY
from
dataflow
import
get_logger
from
dataflow.core
import
OperatorABC
from
dataflow.utils.storage
import
DataFlowStorage
import
pandas
as
pd
from
transformers
import
AutoTokenizer
@OPERATOR_REGISTRY.register()
class ReasoningTokenDatasetEvaluator(OperatorABC):
    """Dataset-level evaluator reporting token-length statistics (zero count,
    min, max, mean, median) for question and answer columns."""

    def __init__(self, model_name_or_path: str):
        self.logger = get_logger()
        self.logger.info(f'Initializing {self.__class__.__name__} ...')
        self.logger.info(f'{self.__class__.__name__} initialized.')
        self.information_name = "Token Information"
        # Tokenizer identifier passed to AutoTokenizer.from_pretrained at run time.
        self.model_name_or_path = model_name_or_path

    @staticmethod
    def get_desc(lang: str = "zh"):
        """Return a human-readable description of the operator in *lang*."""
        if lang == "zh":
            return (
                "该算子用于统计数据集中问题和回答的token信息,包括token数量的最小值、最大值、平均值和中位数等统计指标。"
                "它使用指定的tokenizer对文本进行编码,并计算token长度的分布情况。\n"
                "输入参数:\n"
                "- input_question_key:问题文本字段名\n"
                "- input_answer_key:回答文本字段名\n"
                "- model_name_or_path:tokenizer模型名称或路径\n"
                "输出参数:\n"
                "- 返回包含token统计信息的字典,包括问题和回答的token数量的零值计数、最小值、最大值、平均值和中位数"
            )
        elif lang == "en":
            return (
                "This operator analyzes token information for questions and answers in the dataset, including statistical metrics "
                "such as minimum, maximum, mean, and median token counts. It encodes text using the specified tokenizer and calculates "
                "token length distribution.\n"
                "Input Parameters:\n"
                "- input_question_key: Field name for question text\n"
                "- input_answer_key: Field name for answer text\n"
                "- model_name_or_path: Tokenizer model name or path\n\n"
                "Output Parameters:\n"
                "- Returns a dictionary containing token statistics, including zero count, minimum, maximum, mean, and median token counts "
                "for both questions and answers"
            )
        else:
            return (
                "ToKenInfo analyzes and reports token length statistics for questions and answers in the dataset using a specified tokenizer."
            )

    def get_token_info(self, samples, input_question_key, input_answer_key, model_name_or_path):
        """Tokenize both text columns of *samples* and return their length stats."""
        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

        def token_lengths(field):
            # Missing / falsy values are treated as empty strings.
            texts = [record.get(field, '') or '' for record in samples]
            return [len(tokenizer.encode(text, add_special_tokens=False)) for text in texts]

        def length_stats(lengths):
            # Returns (zero count, min, max, mean, upper median); all zeros
            # for an empty list. Median is the element at len//2 of the
            # sorted list, matching the original upper-median convention.
            if not lengths:
                return 0, 0, 0, 0, 0
            ordered = sorted(lengths)
            return (
                lengths.count(0),
                ordered[0],
                ordered[-1],
                sum(lengths) / len(lengths),
                ordered[len(ordered) // 2],
            )

        q_zeros, q_min, q_max, q_mean, q_median = length_stats(token_lengths(input_question_key))
        a_zeros, a_min, a_max, a_mean, a_median = length_stats(token_lengths(input_answer_key))

        token_info = {
            "questions_zeros_count": q_zeros,
            "answers_zeros_count": a_zeros,
            "questions_min": q_min,
            "questions_max": q_max,
            "questions_mean": q_mean,
            "questions_median": q_median,
            "answers_min": a_min,
            "answers_max": a_max,
            "answers_mean": a_mean,
            "answers_median": a_median
        }
        self.logger.info(f"Token information: {token_info}")
        return token_info

    def run(self, storage: DataFlowStorage, input_question_key: str, input_answer_key: str):
        """Read the dataframe and return token statistics for its question
        and answer columns; {} (after logging) if the question column is absent.
        A missing answer column only logs a warning, matching the original."""
        self.input_question_key = input_question_key
        self.input_answer_key = input_answer_key
        dataframe = storage.read("dataframe")
        if self.input_question_key not in dataframe.columns:
            self.logger.error(f"Input key {self.input_question_key} not found in dataframe columns.")
            return {}
        if self.input_answer_key not in dataframe.columns:
            self.logger.warning(f"Input key {self.input_answer_key} not found in dataframe columns")
        records = dataframe.to_dict(orient='records')
        return self.get_token_info(records, self.input_question_key, self.input_answer_key, self.model_name_or_path)
\ No newline at end of file
dataflow/operators/reasoning/filter/reasoning_answer_formatter_filter.py
0 → 100644
View file @
97e8278b
import
numpy
as
np
import
pandas
as
pd
from
dataflow.utils.registry
import
OPERATOR_REGISTRY
from
dataflow
import
get_logger
from
dataflow.utils.storage
import
DataFlowStorage
from
dataflow.core
import
OperatorABC
@OPERATOR_REGISTRY.register()
class ReasoningAnswerFormatterFilter(OperatorABC):
    """Filter that keeps rows whose answer passes a format check.

    The check is currently a stub that accepts every answer; the
    commented-out regex shows the intended \\boxed{} rule.
    """

    def __init__(self):
        self.logger = get_logger()

    @staticmethod
    def is_valid_answer(answer: str) -> bool:
        """Return True if *answer* is acceptably formatted (currently always True)."""
        # FIX: added @staticmethod — the original defined this without `self`
        # and relied on class-attribute access; the decorator makes it safely
        # callable from instances as well.
        # check final answer in \boxed{} or not
        # if not re.search(r'\\boxed{.*}', answer):
        #     return False
        return True

    @staticmethod
    def get_desc(lang: str = "zh"):
        """Return a human-readable description of the operator in *lang*."""
        if lang == "zh":
            return (
                "该算子用于检查答案格式是否符合规范,主要验证数学答案是否包含正确的\\boxed{}标记。\n\n"
                "输入参数:\n"
                "- input_key:输入字段名\n"
                "- result_key:结果字段名\n\n"
                "输出参数:\n"
                "- 通过格式检查返回1,否则返回0"
            )
        elif lang == "en":
            return (
                "This operator validates answer formatting, specifically checking for correct \\boxed{} notation.\n\n"
                "Input Parameters:\n"
                "- input_key: Field name containing the answer\n"
                "- result_key: Output result field name\n\n"
                "Output Parameters:\n"
                "- Returns 1 for valid format, 0 otherwise"
            )
        else:
            return "AnswerFormatterFilter validates mathematical answer formatting"

    def _validate_dataframe(self, dataframe: pd.DataFrame):
        """Log (without raising) any missing required columns.

        FIX: the original ran the identical missing-column scan twice and
        also checked an always-empty forbidden_keys list; the redundant
        passes were removed.
        """
        required_keys = [self.input_key]
        missing = [k for k in required_keys if k not in dataframe.columns]
        if missing:
            self.logger.error(f"Missing required column(s): {missing}")

    def run(self, storage: DataFlowStorage, input_key: str = "generated_cot") -> list:
        '''
        Execute the answer format filter process
        '''
        self.input_key = input_key
        dataframe = storage.read("dataframe")
        self._validate_dataframe(dataframe)
        # BUG FIX: the original indexed a positional numpy flag array with the
        # labels yielded by iterrows(); for any non-default (non-RangeIndex)
        # index that misaligns or raises. Build a positional boolean mask.
        keep = [bool(self.is_valid_answer(item[self.input_key])) for _, item in dataframe.iterrows()]
        dataframe = dataframe[np.array(keep, dtype=bool)]
        output_file = storage.write(dataframe)
        self.logger.info(f"Results saved to {output_file}")
        return [self.input_key,]
\ No newline at end of file
dataflow/operators/reasoning/filter/reasoning_answer_groundtruth_filter.py
0 → 100644
View file @
97e8278b
from
dataflow.utils.registry
import
OPERATOR_REGISTRY
from
dataflow.utils.reasoning.AnswerExtraction
import
StringCleaner
,
UnitTextManager
,
AnswerExtractor
from
dataflow.utils.storage
import
DataFlowStorage
from
dataflow
import
get_logger
from
dataflow.core
import
OperatorABC
from
typing
import
Literal
from
math_verify
import
parse
,
verify
import
pandas
as
pd
@OPERATOR_REGISTRY.register()
class ReasoningAnswerGroundTruthFilter(OperatorABC):
    """Filter that keeps rows whose extracted answer matches the ground
    truth, using either exact string comparison or math_verify."""

    def __init__(self, compare_method: Literal["math_verify", "exact"] = "math_verify"):
        # Dispatch table: compare_method name -> bound comparison function.
        # Unknown names raise KeyError immediately, surfacing bad config early.
        name2compare = {
            'exact': self.exact_compare,
            'math_verify': self.math_verify_compare
        }
        self.compare = name2compare[compare_method]
        unit_manager = UnitTextManager()
        string_cleaner = StringCleaner(unit_manager)
        self.answer_extractor = AnswerExtractor(string_cleaner)
        self.logger = get_logger()

    def exact_compare(self, answer, ground_truth):
        """String-equality comparison after str() coercion."""
        return str(answer) == str(ground_truth)

    def math_verify_compare(self, answer, ground_truth):
        """Mathematical equivalence via math_verify.

        Tries str()-coerced operands first, then the raw values; returns
        False if both attempts fail.  FIX: the bare `except:` clauses were
        narrowed to `except Exception:` so KeyboardInterrupt/SystemExit
        are no longer swallowed.
        """
        try:
            return verify(parse(str(ground_truth)), parse(str(answer)))
        except Exception:
            try:
                return verify(parse(ground_truth), parse(answer))
            except Exception:
                return False

    @staticmethod
    def get_desc(lang: str = "zh"):
        """Return a human-readable description of the operator in *lang*."""
        if lang == "zh":
            return (
                "该算子用于对比预测答案与标准答案的匹配度,支持精确匹配和数学验证两种方式。\n\n"
                "输入参数:\n"
                "- input_test_answer_key:预测答案字段名\n"
                "- input_gt_answer_key:标准答案字段名\n"
                "- compare_method:比较方法(exact/math_verify)\n\n"
                "输出参数:\n"
                "- 匹配成功返回1,否则返回0"
            )
        elif lang == "en":
            return (
                "This operator compares predicted answers against ground truth using exact or mathematical verification.\n\n"
                "Input Parameters:\n"
                "- test_answer_key: Predicted answer field\n"
                "- gt_answer_key: Ground truth field\n"
                "- compare_method: Comparison method (exact/math_verify)\n\n"
                "Output Parameters:\n"
                "- Returns 1 for matches, 0 otherwise"
            )
        else:
            return "AnswerGroundTruthFilter performs answer validation"

    def run(self, storage: DataFlowStorage, input_test_answer_key: str = "generated_cot", input_gt_answer_key: str = "golden_answer") -> list:
        """Keep only rows whose extracted answer matches the ground truth;
        write the filtered dataframe back and return the compared keys."""
        self.test_answer_key = input_test_answer_key
        self.gt_answer_key = input_gt_answer_key
        dataframe = storage.read("dataframe")
        output = []
        answers = dataframe[self.test_answer_key]
        ground_truths = dataframe[self.gt_answer_key]
        for i in range(len(answers)):
            # BUG FIX: use positional .iloc consistently — the original mixed
            # label-based Series indexing (answers[i]) with dataframe.iloc[i],
            # which misaligns or raises on any non-default index.
            final_answer = self.answer_extractor.extract_answer(answers.iloc[i], None)
            if self.compare(final_answer, ground_truths.iloc[i]):
                output.append(dataframe.iloc[i])
        output = pd.DataFrame(output)
        output_file = storage.write(output)
        self.logger.info(f"Filtered data saved to {output_file}")
        return [self.test_answer_key, self.gt_answer_key]
\ No newline at end of file
dataflow/operators/reasoning/filter/reasoning_answer_model_judge_filter.py
0 → 100644
View file @
97e8278b
from
dataflow.utils.registry
import
OPERATOR_REGISTRY
from
dataflow
import
get_logger
from
dataflow.core
import
OperatorABC
from
dataflow.utils.storage
import
DataFlowStorage
from
dataflow.core
import
LLMServingABC
from
dataflow.prompts.model_evaluation.general
import
AnswerJudgePromptQuestion
,
AnswerJudgePrompt
from
dataflow.core.prompt
import
prompt_restrict
,
DIYPromptABC
import
re
import
pandas
as
pd
import
numpy
as
np
from
typing
import
Union
@prompt_restrict(
    AnswerJudgePromptQuestion,
    AnswerJudgePrompt,
)
@OPERATOR_REGISTRY.register()
class ReasoningAnswerModelJudgeFilter(OperatorABC):
    """Filter operator that asks an LLM to judge whether each answer is
    semantically consistent with its reference answer, then keeps either
    the matching rows only or all rows with a judgment column attached."""

    def __init__(
        self,
        system_prompt: str = "You are a helpful assistant specialized in evaluating answer correctness.",
        llm_serving: LLMServingABC = None,
        prompt_template: Union[AnswerJudgePromptQuestion, AnswerJudgePrompt, DIYPromptABC] = AnswerJudgePromptQuestion,
        keep_all_samples: bool = False,  # when True, keep every row and only annotate the judgment
    ):
        self.logger = get_logger()
        # NOTE(review): the default is the AnswerJudgePromptQuestion *class*
        # while the None-fallback is an AnswerJudgePrompt *instance* — the
        # inconsistency is preserved here; confirm which form build_prompt
        # expects before changing it.
        if prompt_template is None:
            prompt_template = AnswerJudgePrompt()
        self.prompt_template = prompt_template
        self.system_prompt = system_prompt
        self.llm_serving = llm_serving
        # Counter for empty LLM responses, reported (and reset) at the end
        # of each run().
        self.empty_responses_count = 0
        self.keep_all_samples = keep_all_samples

    @staticmethod
    def get_desc(lang: str = "zh"):
        """Return a human-readable operator description in *lang* (zh/en)."""
        if lang == "zh":
            return (
                "该算子用于对答案进行正确性评判,通过比较当前答案与参考答案的语义一致性,判断答案是否正确。"
                "调用大语言模型进行语义理解和判断,最终返回每个答案是否正确的二分类结果。\n"
                "输入参数:\n"
                "- system_prompt:系统提示词,用于定义模型行为\n"
                "- llm_serving:LLM服务对象,需实现LLMServingABC接口\n"
                "- prompt_template:提示模板对象,用于构建评判提示词\n"
                "- keep_all_samples:是否保留所有样本,默认为False(仅保留正确答案)\n"
                "- question_key:问题字段名,默认为'question'\n"
                "- answer_key:当前答案字段名,默认为'answer'\n"
                "- reference_key:参考答案字段名,默认为'reference_answer'\n"
                "输出参数:\n"
                "- DataFrame,包含原始数据和判断结果(answer_match_result字段)\n"
                "- 如果keep_all_samples为False,则仅保留判断结果为True的行\n"
                "- 返回包含输入字段名的列表,用于后续算子引用"
            )
        elif lang == "en":
            return (
                "This operator evaluates the correctness of answers by comparing the semantic consistency between "
                "the current answer and the reference answer. It uses a large language model for semantic understanding "
                "and judgment, ultimately returning a binary classification result for each answer.\n"
                "Input Parameters:\n"
                "- system_prompt: System prompt to define model behavior\n"
                "- llm_serving: LLM serving object implementing LLMServingABC interface\n"
                "- prompt_template: Prompt template object for constructing evaluation prompts\n"
                "- keep_all_samples: Whether to keep all samples, default is False (only keep correct answers)\n"
                "- question_key: Field name for questions, default is 'question'\n"
                "- answer_key: Field name for current answers, default is 'answer'\n"
                "- reference_key: Field name for reference answers, default is 'reference_answer'\n\n"
                "Output Parameters:\n"
                "- DataFrame containing original data and judgment results (answer_match_result field)\n"
                "- If keep_all_samples is False, only rows with True judgment results are retained\n"
                "- List containing input field names for subsequent operator reference"
            )
        else:
            return (
                "AnswerJudge evaluates answer correctness by comparing semantic consistency with reference answers using LLM."
            )

    def ResolveResponse(self, response):
        """Parse an LLM judgment response into a boolean.

        Looks for a JSON-style ``"judgement_result": true|false`` field;
        if absent, falls back to scanning the whole response for the word
        "true". Empty/None responses are counted and treated as False.
        """
        # Count empty responses so run() can report them at the end.
        if response is None or (isinstance(response, str) and response.strip() == ''):
            self.empty_responses_count += 1
            return False
        try:
            pattern = re.compile(r'"judgement_result"\s*:\s*(true|false)', re.IGNORECASE)
            match = pattern.search(response)
            result_value = None
            if match:
                result_value = match.group(1).lower()
            else:
                # Fallback parsing: check whether the response mentions
                # "true" anywhere (anything else is treated as false).
                if "true" in response.lower():
                    result_value = "true"
                else:
                    result_value = "false"
            if result_value == "true":
                return True
            else:
                return False
        except Exception as e:
            self.logger.error(f"Response format error: {response}. Error: {e}")
            return False

    def run(
        self,
        storage: DataFlowStorage,
        input_question_key: str = "question",
        input_answer_key: str = "answer",
        input_reference_key: str = "reference_answer"
    ) -> list:
        """Judge each answer against its reference via the LLM and filter.

        Rows with missing/empty reference answers are skipped (judged
        False). Writes either all rows annotated with
        ``answer_match_result`` (keep_all_samples=True) or only the
        matching rows.

        Args:
            storage: DataFlowStorage to read the dataframe from and write
                the result to.
            input_question_key: Column holding the questions.
            input_answer_key: Column holding the answers to judge.
            input_reference_key: Column holding the reference answers.

        Returns:
            The required input column names plus 'answer_match_result'.
        """
        self.question_key = input_question_key
        self.answer_key = input_answer_key
        self.reference_key = input_reference_key
        dataframe = storage.read("dataframe")
        # Verify all required columns exist before doing any work.
        required_columns = [input_question_key, input_answer_key, input_reference_key]
        for column in required_columns:
            if column not in dataframe.columns:
                self.logger.error(f"Required column '{column}' not found in dataframe")
                return required_columns
        # Split rows into those with and without a usable reference answer.
        empty_reference_mask = dataframe[input_reference_key].isna() | (dataframe[input_reference_key] == '')
        skipped_rows = dataframe[empty_reference_mask]
        valid_rows = dataframe[~empty_reference_mask]
        skipped_count = len(skipped_rows)
        # Default every row to False; valid rows are updated after judging.
        dataframe['answer_match_result'] = False
        if len(valid_rows) == 0:
            self.logger.warning("No valid samples with reference answers found. All samples skipped.")
            if self.keep_all_samples:
                # Keep all rows, with answer_match_result uniformly False.
                output_file = storage.write(dataframe)
            else:
                # Keep no rows, but preserve the column schema.
                output_file = storage.write(pd.DataFrame(columns=dataframe.columns))
            self.logger.info(f"Dataframe saved to {output_file}. Skipped {skipped_count} samples due to missing reference answers.")
            return required_columns + ['answer_match_result']
        # Build one judgment prompt per valid row and query the LLM in batch.
        inputs = [self.prompt_template.build_prompt(
            question=row[input_question_key],
            answer=row[input_answer_key],
            reference_answer=row[input_reference_key]
        ) for _, row in valid_rows.iterrows()]
        responses = self.llm_serving.generate_from_input(user_inputs=inputs, system_prompt=self.system_prompt)
        results = [self.ResolveResponse(response) for response in responses]
        # Write each judgment back to its original row by index label.
        # (The original code also built an unused `result_mask` ndarray
        # here; removed as dead code.)
        valid_indices = valid_rows.index
        for i, idx in enumerate(valid_indices):
            dataframe.at[idx, 'answer_match_result'] = results[i]
        # Either keep everything (annotated) or only the matching rows.
        if self.keep_all_samples:
            final_dataframe = dataframe
        else:
            final_dataframe = dataframe[dataframe['answer_match_result'] == True]
        output_file = storage.write(final_dataframe)
        # Log run statistics.
        total_samples = len(dataframe)
        valid_samples = len(valid_rows)
        matched_samples = sum(results)
        accuracy = matched_samples / valid_samples if valid_samples > 0 else 0
        self.logger.info(f"Processed answers saved to {output_file}.")
        self.logger.info(f"Total samples: {total_samples}, Valid samples: {valid_samples}, Skipped samples: {skipped_count}")
        self.logger.info(f"Matched answers: {matched_samples}, Accuracy: {accuracy:.2%}")
        self.logger.info(f"Output samples: {len(final_dataframe)}")
        # Report and reset the empty-response counter.
        if self.empty_responses_count > 0:
            self.logger.error(f"Found {self.empty_responses_count} empty responses during evaluation.")
            self.empty_responses_count = 0
        return required_columns + ['answer_match_result']
\ No newline at end of file
Prev
1
…
7
8
9
10
11
12
13
14
15
…
20
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment