Commit 97e8278b authored by zzg_666's avatar zzg_666
Browse files

适配后端vllm

parents
Pipeline #3071 canceled with stages
import pandas as pd
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow import get_logger
from dataflow.utils.storage import DataFlowStorage
from dataflow.core import OperatorABC
import os
from pathlib import Path
from trafilatura import fetch_url, extract
from urllib.parse import urlparse
from tqdm import tqdm
import requests
def is_url(string):
    """Return True when *string* parses as an absolute URL (has both a scheme and a netloc)."""
    try:
        parsed = urlparse(string)
    except ValueError:
        return False
    return bool(parsed.scheme) and bool(parsed.netloc)
def _parse_file_with_mineru(raw_file: str, output_file: str, mineru_backend: str = "vlm-vllm-engine") -> str:
    """
    Uses MinerU to parse PDF/image files (pdf/png/jpg/jpeg/webp/gif) into Markdown files.
    Internally, the parsed outputs for each item are stored in a structured directory:
    'intermediate_dir/pdf_name/MinerU_Version[mineru_backend]'.
    This directory stores various MinerU parsing outputs, and you can customize
    which content to extract based on your needs.
    Args:
        raw_file: Input file path, supports .pdf/.png/.jpg/.jpeg/.webp/.gif
        output_file: Directory under which MinerU writes its intermediate outputs
            (a "mineru" subdirectory is created inside it).
        mineru_backend: Sets the backend engine for MinerU. Options include:
                        - "pipeline": Traditional pipeline processing (MinerU1)
                        - "vlm-transformers" / "vlm-vllm-engine" / "vlm-http-client":
                          engines based on multimodal language models (MinerU2)
                        Defaults to "vlm-vllm-engine" (note: this matches the actual
                        signature default; an earlier docstring wrongly said
                        "vlm-sglang-engine").
                        For more details, refer to the MinerU GitHub: https://github.com/opendatalab/MinerU.
    Returns:
        output_file: Path to the generated Markdown file
    Raises:
        Exception: when the `mineru` package is not installed.
        RuntimeError: when the `mineru` CLI subprocess fails.
    """
    try:
        import mineru  # noqa: F401 -- imported only to verify the package is installed
    except ImportError:
        raise Exception(
            """
MinerU is not installed in this environment yet.
Please refer to https://github.com/opendatalab/mineru to install.
Or you can just execute 'pip install mineru[pipeline]' and 'mineru-models-download' to fix this error.
Please make sure you have GPU on your machine.
"""
        )
    logger = get_logger()
    os.environ['MINERU_MODEL_SOURCE'] = "local"  # optional: load models from the local cache
    # Maps each backend (pipeline|vlm-transformers|vlm-vllm-engine|vlm-http-client)
    # to the output subdirectory name MinerU uses for it.
    MinerU_Version = {"pipeline": "auto", "vlm-transformers": "vlm", 'vlm-vllm-engine': 'vlm', 'vlm-http-client': 'vlm'}
    raw_file = Path(raw_file)
    pdf_name = raw_file.stem  # was Path(raw_file).stem -- redundant second Path() wrap removed
    intermediate_dir = os.path.join(output_file, "mineru")
    import subprocess
    command = [
        "mineru",
        "-p", str(raw_file),  # pass a plain string, not a Path, in argv
        "-o", intermediate_dir,
        "-b", mineru_backend,
        "--source", "local"
    ]
    try:
        subprocess.run(
            command,
            check=True
        )
    except Exception as e:
        # Chain the original exception so the subprocess failure is preserved.
        raise RuntimeError(f"Failed to process file with MinerU: {str(e)}") from e
    # Directory for storing raw data, including various MinerU parsing outputs.
    # You can customize which content to extract based on your needs.
    PerItemDir = os.path.join(intermediate_dir, pdf_name, MinerU_Version[mineru_backend])
    output_file = os.path.join(PerItemDir, f"{pdf_name}.md")
    logger.info(f"Markdown saved to: {output_file}")
    return output_file
def _parse_xml_to_md(raw_file: str = None, url: str = None, output_file: str = None):
    """Extract the main content of an HTML/XML document and write it as Markdown.

    Exactly one of `raw_file` / `url` should be given; `url` takes precedence.
    On URL-fetch failure an explanatory message is written to `output_file`
    and the function returns immediately.

    Args:
        raw_file: Path to a local HTML/XML file to read.
        url: URL to fetch via trafilatura.
        output_file: Destination path for the Markdown output.
    Returns:
        output_file: the destination path (returned even if extraction failed).
    Raises:
        Exception: when neither `raw_file` nor `url` is supplied.
    """
    logger = get_logger()
    if url:
        downloaded = fetch_url(url)
        if not downloaded:
            # Record the failure message in the output file so downstream steps see it.
            downloaded = "fail to fetch this url. Please check your Internet Connection or URL correctness"
            with open(output_file, "w", encoding="utf-8") as f:
                f.write(downloaded)
            return output_file
    elif raw_file:
        with open(raw_file, "r", encoding='utf-8') as f:
            downloaded = f.read()
    else:
        raise Exception("Please provide at least one of file path and url string.")
    try:
        result = extract(downloaded, output_format="markdown", with_metadata=True)
        if result is None:
            # extract() returns None when no main content is found; the old code
            # then crashed on f.write(None), which the except clause swallowed.
            logger.error(f"No content could be extracted; nothing written to {output_file}")
        else:
            logger.info(f"Extracted content is written into {output_file}")
            with open(output_file, "w", encoding="utf-8") as f:
                f.write(result)
    except Exception as e:
        # The original called logger.error("...: ", e), which passes `e` as an
        # unused lazy-format argument; interpolate it properly instead.
        logger.error(f"Error during extract this file or link: {e}")
    return output_file
def is_pdf_url(url):
    """Return True when `url` serves a PDF, determined via an HTTP HEAD request.

    Sends a HEAD request (headers only, no body download) and inspects the
    Content-Type. Prints a diagnostic and returns False on any failure.
    """
    try:
        # HEAD request fetches only headers; a timeout prevents hanging forever
        # (the original request had no timeout).
        response = requests.head(url, allow_redirects=True, timeout=30)
        content_type = response.headers.get('Content-Type', '')
        # Content-Type may carry parameters ("application/pdf; charset=..."),
        # so compare only the media type, not the full header value.
        if response.status_code == 200 and content_type.split(';')[0].strip().lower() == 'application/pdf':
            return True
        print(f"Content-Type: {response.headers.get('Content-Type')}")
        return False
    except requests.exceptions.RequestException:
        # Network errors / timeouts are treated as "not a PDF".
        print("Request failed")
        return False
def download_pdf(url, save_path):
    """Download a PDF from `url` and save it to `save_path`.

    Creates the parent directory if needed and streams the response in 1 KiB
    chunks. Prints a diagnostic if the URL does not serve a PDF or the
    request fails; never raises.
    """
    try:
        # Stream the body so large PDFs are not held fully in memory; the
        # timeout prevents an unresponsive server from hanging the pipeline.
        response = requests.get(url, stream=True, timeout=60)
        content_type = response.headers.get('Content-Type', '')
        # Compare only the media type -- "application/pdf; charset=..." must match too.
        if response.status_code == 200 and content_type.split(';')[0].strip().lower() == 'application/pdf':
            pdf_folder = os.path.dirname(save_path)
            os.makedirs(pdf_folder, exist_ok=True)
            with open(save_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=1024):
                    if chunk:
                        f.write(chunk)
            print(f"PDF saved to {save_path}")
        else:
            print("The URL did not return a valid PDF file.")
    except requests.exceptions.RequestException as e:
        print(f"Error downloading PDF: {e}")
@OPERATOR_REGISTRY.register()
class FileOrURLToMarkdownConverterBatch(OperatorABC):
    """
    Converts each dataframe entry (local file path or URL) into a Markdown file.

    mineru_backend sets the backend engine for MinerU. Options include:
    - "pipeline": Traditional pipeline processing (MinerU1)
    - "vlm-sglang-engine": New engine based on multimodal language models (MinerU2) (default recommended)
    Choose the appropriate backend based on your needs. Defaults to "vlm-sglang-engine".
    For more details, refer to the MinerU GitHub: https://github.com/opendatalab/MinerU.
    """
    def __init__(self, intermediate_dir: str = "intermediate", lang: str = "en", mineru_backend: str = "vlm-sglang-engine"):
        # intermediate_dir: directory where converted Markdown files are written (created eagerly).
        # lang: document-language hint carried through the pipeline (not read in this class's code).
        # mineru_backend: engine name forwarded to _parse_file_with_mineru for PDFs/images.
        self.logger = get_logger()
        self.intermediate_dir=intermediate_dir
        os.makedirs(self.intermediate_dir, exist_ok=True)
        self.lang=lang
        self.mineru_backend = mineru_backend
    @staticmethod
    def get_desc(lang: str = "zh"):
        """
        Return a human-readable description of this operator (mirrors run()'s behavior).
        """
        if lang == "zh":
            return (
                "知识提取算子:支持从多种文件格式中提取结构化内容并转换为标准Markdown\n"
                "核心功能:\n"
                "1. PDF文件:使用MinerU解析引擎提取文本/表格/公式,保留原始布局\n"
                "2. Office文档(DOC/PPT等):通过DocConverter转换为Markdown格式\n"
                "3. 网页内容(HTML/XML):使用trafilatura提取正文并转为Markdown\n"
                "4. 纯文本(TXT/MD):直接透传不做处理\n"
                "特殊处理:\n"
                "- 自动识别中英文文档(lang参数)\n"
                "- 支持本地文件路径和URL输入\n"
                "- 生成中间文件到指定目录(intermediate_dir)"
            )
        else:  # default to English
            return (
                "Knowledge Extractor: Converts multiple file formats to structured Markdown\n"
                "Key Features:\n"
                "1. PDF: Uses MinerU engine to extract text/tables/formulas with layout preservation\n"
                "2. Office(DOC/PPT): Converts to Markdown via DocConverter\n"
                "3. Web(HTML/XML): Extracts main content using trafilatura\n"
                "4. Plaintext(TXT/MD): Directly passes through without conversion\n"
                "Special Handling:\n"
                "- Auto-detects Chinese/English documents(lang param)\n"
                "- Supports both local files and URLs\n"
                "- Generates intermediate files to specified directory(intermediate_dir)"
            )
    def run(self, storage: DataFlowStorage, input_key: str = "source", output_key: str = "text_path"):
        """
        Convert every entry of the `input_key` column (URL or local file path)
        to a Markdown file and record the resulting path under `output_key`.

        Per-entry behavior:
        - PDF URL: downloaded locally, then handled like a local PDF.
        - Other URL: main content extracted straight to Markdown.
        - Local .pdf/.png/.jpg/.jpeg/.webp/.gif: parsed with MinerU.
        - Local .html/.xml: extracted with trafilatura.
        - Local .txt/.md: passed through unchanged.
        - Missing file or unsupported extension: "" is recorded for that row.

        Returns:
            The path of the dataframe file written back to storage.
        """
        self.logger.info("Starting content extraction...")
        self.logger.info("If the input is a URL or a large file, this process might take some time. Please wait...")
        dataframe = storage.read("dataframe")
        self.logger.info(f"Loaded dataframe with {len(dataframe)} entries.")
        output_file_all = []
        # Wrap iterrows with tqdm for progress tracking
        for index, row in tqdm(dataframe.iterrows(), total=len(dataframe), desc="FileOrURLToMarkdownConverter Processing files", ncols=100):
            content = row.get(input_key, "")
            if is_url(content):
                # Case: Input is a URL
                if is_pdf_url(content):
                    # PDF URL: download it next to the storage entry file, then
                    # fall through to the local-file handling below.
                    pdf_save_path = os.path.join(
                        os.path.dirname(storage.first_entry_file_name),
                        f"raw/crawled/crawled_{index}.pdf"
                    )
                    self.logger.info(f"Downloading PDF from {content} to {pdf_save_path}")
                    download_pdf(content, pdf_save_path)
                    content = pdf_save_path
                    self.logger.info(f"pdf file has been fetched and saved to {pdf_save_path}")
                else:
                    # Non-PDF URL: extract page content directly to Markdown and
                    # move on to the next row.
                    output_file = os.path.join(
                        os.path.dirname(storage.first_entry_file_name),
                        f"raw/crawled/crawled_{index}.md"
                    )
                    os.makedirs(os.path.dirname(output_file),exist_ok=True)
                    output_file = _parse_xml_to_md(url=content, output_file=output_file)
                    self.logger.info(f"Primary extracted result written to: {output_file}")
                    output_file_all.append(output_file)
                    continue
            # Extract file name and extension
            raw_file = content
            raw_file_name = os.path.splitext(os.path.basename(raw_file))[0]
            raw_file_suffix = os.path.splitext(raw_file)[1].lower()
            raw_file_suffix_no_dot = raw_file_suffix.lstrip(".")
            # Define default output path (used by the HTML/XML branch below;
            # the MinerU branch computes its own path).
            output_file = os.path.join(
                self.intermediate_dir,
                f"{raw_file_name}_{raw_file_suffix_no_dot}.md"
            )
            # Case: Local file path
            if not os.path.exists(content):
                self.logger.error(f"File not found: Path {content} does not exist.")
                output_file_all.append("")
                continue
            _, ext = os.path.splitext(content)
            ext = ext.lower()
            if ext in [".pdf", ".png", ".jpg", ".jpeg", ".webp", ".gif"]:
                self.logger.info(f"Using MinerU backend: {self.mineru_backend}")
                output_file = _parse_file_with_mineru(
                    raw_file=content,
                    output_file=self.intermediate_dir,
                    mineru_backend=self.mineru_backend
                )
            elif ext in [".html", ".xml"]:
                output_file = _parse_xml_to_md(raw_file=content, output_file=output_file)
            elif ext in [".txt", ".md"]:
                output_file = content  # No parsing needed for plain text or Markdown files
            else:
                self.logger.error(f"Unsupported file type: {ext} for file {content}")
                output_file = ""
            output_file_all.append(output_file)
        # Save results back to storage
        dataframe[output_key] = output_file_all
        output_file_path = storage.write(dataframe)
        self.logger.info(f"Final extraction results saved to: {output_file_path}")
        return output_file_path
import os
import json
from typing import Dict, List, Optional
from chonkie import (
TokenChunker,
SentenceChunker,
SemanticChunker,
RecursiveChunker
)
from tokenizers import Tokenizer
from transformers import AutoTokenizer
import pandas as pd
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow import get_logger
from dataflow.utils.storage import DataFlowStorage
from dataflow.core import OperatorABC
@OPERATOR_REGISTRY.register()
class KBCChunkGenerator(OperatorABC):
    """Split each document referenced in the dataframe into text chunks.

    Reads file paths from the `input_key` column, loads the text, splits it
    with the configured chonkie chunker, and emits one output row per chunk
    (original row fields preserved, chunk text stored under `output_key`).
    """
    def __init__(self,
                 chunk_size: int = 512,
                 chunk_overlap: int = 50,
                 split_method: str = "token",
                 min_tokens_per_chunk: int = 100,
                 tokenizer_name: str = "bert-base-uncased",
                 ):
        """
        Args:
            chunk_size: Maximum tokens per chunk.
            chunk_overlap: Token overlap between consecutive chunks.
            split_method: One of "token" / "sentence" / "semantic" / "recursive".
            min_tokens_per_chunk: Minimum chunk length (stored; not read in this class's code).
            tokenizer_name: HuggingFace tokenizer used for token counting.
        """
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.split_method = split_method
        self.min_tokens_per_chunk = min_tokens_per_chunk
        # Initialize tokenizer and chunker. (A no-op `tokenizer_name = tokenizer_name`
        # self-assignment was removed here.)
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.chunker = self._initialize_chunker()
        self.logger = get_logger()
    @staticmethod
    def get_desc(lang: str = "zh"):
        # NOTE: returns a tuple of strings (the trailing commas make these
        # tuples, not concatenated strings) -- kept for backward compatibility.
        if lang == "zh":
            return (
                "CorpusTextSplitter是轻量级文本分割工具,",
                "支持词/句/语义/递归分块,",
                "可配置块大小、重叠和最小块长度",
            )
        elif lang == "en":
            return (
                "CorpusTextSplitter is a lightweight text segmentation tool",
                "that supports multiple chunking methods",
                "(token/sentence/semantic/recursive) with configurable size and overlap,",
                "optimized for RAG applications."
            )
    def _initialize_chunker(self):
        """Build the chonkie chunker matching `split_method`; raise on unknown methods."""
        if self.split_method == "token":
            return TokenChunker(
                tokenizer=self.tokenizer,
                chunk_size=self.chunk_size,
                chunk_overlap=self.chunk_overlap
            )
        elif self.split_method == "sentence":
            return SentenceChunker(
                chunk_size=self.chunk_size,
                chunk_overlap=self.chunk_overlap
            )
        elif self.split_method == "semantic":
            return SemanticChunker(
                chunk_size=self.chunk_size,
            )
        elif self.split_method == "recursive":
            return RecursiveChunker(
                chunk_size=self.chunk_size,
                chunk_overlap=self.chunk_overlap
            )
        else:
            raise ValueError(f"Unsupported split method: {self.split_method}")
    def _load_text(self, file_path) -> str:
        """Load text from an input file (.txt/.md/.xml read verbatim; .json/.jsonl
        searched for a 'text'/'content'/'body' field).

        Raises:
            FileNotFoundError: when the path does not exist.
            ValueError: on unsupported formats or JSON without a text field.
        """
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Input file not found: {file_path}")
        if file_path.endswith(('.txt', '.md', '.xml')):
            with open(file_path, 'r', encoding='utf-8') as f:
                return f.read()
        elif file_path.endswith(('.json', '.jsonl')):
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f) if file_path.endswith('.json') else [json.loads(line) for line in f]
            text_fields = ['text', 'content', 'body']
            for field in text_fields:
                if isinstance(data, list) and len(data) > 0 and field in data[0]:
                    return "\n".join([item[field] for item in data])
                elif isinstance(data, dict) and field in data:
                    return data[field]
            raise ValueError("No text field found in JSON input")
        else:
            raise ValueError("Unsupported file format")
    def _validate_dataframe(self, dataframe: pd.DataFrame):
        """Refuse to overwrite an existing `output_key` column."""
        forbidden_keys = [self.output_key]
        conflict = [k for k in forbidden_keys if k in dataframe.columns]
        if conflict:
            raise ValueError(f"The following column(s) already exist and would be overwritten: {conflict}")
    def run(self, storage: DataFlowStorage, input_key: str = 'text_path', output_key: str = "raw_chunk"):
        """Split every referenced file into chunks; emit one dataframe row per chunk.

        Rows whose input path is missing are skipped (with an error logged)
        rather than crashing the whole run.
        """
        self.input_key = input_key
        self.output_key = output_key
        dataframe = storage.read("dataframe")
        self._validate_dataframe(dataframe)
        text_paths = dataframe[self.input_key].tolist()
        new_records = []
        for row_dict, text_path in zip(dataframe.to_dict(orient='records'), text_paths):
            # BUG FIX: the original logged invalid paths in a first pass but then
            # still called _load_text on them, raising FileNotFoundError. Skip instead.
            if not text_path or not os.path.exists(text_path):
                self.logger.error(f"无效的输入文件路径: {text_path}")
                continue
            text = self._load_text(text_path)
            if text:
                # Guard against inputs longer than the tokenizer's hard limit.
                tokens = self.tokenizer.encode(text)
                total_tokens = len(tokens)
                max_tokens = self.tokenizer.model_max_length
                if total_tokens <= max_tokens:
                    chunks = self.chunker(text)
                else:
                    # Pre-split into x roughly equal word slices (ceiling division)
                    # so each slice fits within the tokenizer limit, then chunk each.
                    x = (total_tokens + max_tokens - 1) // max_tokens
                    words = text.split()
                    words_per_chunk = (len(words) + x - 1) // x
                    chunks = []
                    for i in range(0, len(words), words_per_chunk):
                        chunks.extend(self.chunker(' '.join(words[i:i + words_per_chunk])))
                # One output record per chunk; original row fields are preserved.
                for chunk in chunks:
                    new_row = row_dict.copy()
                    new_row[self.output_key] = chunk.text
                    new_records.append(new_row)
        new_df = pd.DataFrame(new_records)
        output_file = storage.write(new_df)
        self.logger.info(f"Successfully split text for {len(text_paths)} files. Saved to {output_file}")
        return [output_key]
\ No newline at end of file
import os
import json
from typing import Dict, List, Optional
from chonkie import (
TokenChunker,
SentenceChunker,
SemanticChunker,
RecursiveChunker
)
from tokenizers import Tokenizer
from transformers import AutoTokenizer
import pandas as pd
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow import get_logger
from dataflow.utils.storage import DataFlowStorage
from dataflow.core import OperatorABC
@OPERATOR_REGISTRY.register()
class KBCChunkGeneratorBatch(OperatorABC):
    """Split each referenced document into chunks and write them to per-file JSON.

    Unlike KBCChunkGenerator, this batch variant writes the chunks of each
    input file to `<input_dir>/extract/<name>_chunk.json` and stores that path
    in the dataframe under `output_key`.
    """
    def __init__(self,
                 chunk_size: int = 512,
                 chunk_overlap: int = 50,
                 split_method: str = "token",
                 min_tokens_per_chunk: int = 100,
                 tokenizer_name: str = "bert-base-uncased",
                 ):
        """
        Args:
            chunk_size: Maximum tokens per chunk.
            chunk_overlap: Token overlap between consecutive chunks.
            split_method: One of "token" / "sentence" / "semantic" / "recursive".
            min_tokens_per_chunk: Minimum chunk length (stored; not read in this class's code).
            tokenizer_name: HuggingFace tokenizer used for token counting.
        """
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.split_method = split_method
        self.min_tokens_per_chunk = min_tokens_per_chunk
        # Initialize tokenizer and chunker. (A no-op `tokenizer_name = tokenizer_name`
        # self-assignment was removed here.)
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.chunker = self._initialize_chunker()
        self.logger = get_logger()
    @staticmethod
    def get_desc(lang: str = "zh"):
        # NOTE: returns a tuple of strings (the trailing commas make these
        # tuples, not concatenated strings) -- kept for backward compatibility.
        if (lang == "zh"):
            return (
                "CorpusTextSplitter是轻量级文本分割工具,",
                "支持词/句/语义/递归分块,",
                "可配置块大小、重叠和最小块长度",
            )
        elif (lang == "en"):
            return (
                "CorpusTextSplitter is a lightweight text segmentation tool",
                "that supports multiple chunking methods",
                "(token/sentence/semantic/recursive) with configurable size and overlap,",
                "optimized for RAG applications."
            )
    def _initialize_chunker(self):
        """Build the chonkie chunker matching `split_method`; raise on unknown methods."""
        if self.split_method == "token":
            return TokenChunker(
                tokenizer=self.tokenizer,
                chunk_size=self.chunk_size,
                chunk_overlap=self.chunk_overlap
            )
        elif self.split_method == "sentence":
            return SentenceChunker(
                chunk_size=self.chunk_size,
                chunk_overlap=self.chunk_overlap
            )
        elif self.split_method == "semantic":
            return SemanticChunker(
                chunk_size=self.chunk_size,
            )
        elif self.split_method == "recursive":
            return RecursiveChunker(
                chunk_size=self.chunk_size,
                chunk_overlap=self.chunk_overlap
            )
        else:
            raise ValueError(f"Unsupported split method: {self.split_method}")
    def _load_text(self, text_paths: List[str]) -> List[str]:
        """Load one text string per input path (missing files yield "").

        The returned list is always aligned 1:1 with `text_paths`.

        Raises:
            ValueError: for unsupported formats, or JSON inputs with none of
                the recognized text fields ('text'/'content'/'body').
        """
        texts = []
        for text_path in text_paths:
            if not os.path.exists(text_path):
                self.logger.error(f"Input file not found: {text_path}")
                texts.append("")
            elif text_path.endswith(('.txt', '.md', '.xml')):
                with open(text_path, 'r', encoding='utf-8') as f:
                    texts.append(f.read())
            elif text_path.endswith(('.json', '.jsonl')):
                with open(text_path, 'r', encoding='utf-8') as f:
                    data = json.load(f) if text_path.endswith(
                        '.json') else [json.loads(line) for line in f]
                # BUG FIX: the original checked `if (field not in text_fields)`,
                # which is always False because `field` iterates text_fields --
                # so "no text field" never raised and nothing was appended,
                # silently misaligning `texts` with `text_paths` (and multiple
                # matching fields appended multiple entries). Take the first
                # matching field exactly once.
                extracted = None
                for field in ('text', 'content', 'body'):
                    if isinstance(data, list) and len(data) > 0 and field in data[0]:
                        extracted = "\n".join(item[field] for item in data)
                        break
                    if isinstance(data, dict) and field in data:
                        extracted = data[field]
                        break
                if extracted is None:
                    raise ValueError("No text field found in JSON input")
                texts.append(extracted)
            else:
                raise ValueError(f"Unsupported file format for {text_path}")
        return texts
    def _validate_dataframe(self, dataframe: pd.DataFrame):
        """Refuse to overwrite an existing `output_key` column."""
        forbidden_keys = [self.output_key]
        conflict = [k for k in forbidden_keys if k in dataframe.columns]
        if conflict:
            raise ValueError(
                f"The following column(s) already exist and would be overwritten: {conflict}")
    def run(self, storage: DataFlowStorage, input_key: str = "text_path", output_key: str = "chunk_path"):
        """Split each referenced file into chunks saved as JSON next to the input.

        Adds `output_key` (path of the chunk JSON, or "" when the input could
        not be loaded) to the dataframe and writes it back to storage.
        """
        self.input_key = input_key
        self.output_key = output_key
        dataframe = storage.read("dataframe")
        self._validate_dataframe(dataframe)
        text_paths = dataframe[self.input_key].tolist()
        texts = self._load_text(text_paths)
        output_paths = []
        total_chunks = 0
        for i, text in enumerate(texts):
            if not text:
                output_paths.append("")
                continue
            # Guard against inputs longer than the tokenizer's hard limit.
            tokens = self.tokenizer.encode(text)
            total_tokens = len(tokens)
            max_tokens = self.tokenizer.model_max_length
            if total_tokens <= max_tokens:
                chunks = self.chunker(text)
            else:
                # Pre-split into x roughly equal word slices (ceiling division)
                # so each slice fits within the tokenizer limit, then chunk each.
                x = (total_tokens + max_tokens - 1) // max_tokens
                words = text.split()
                words_per_chunk = (len(words) + x - 1) // x
                chunks = []
                for j in range(0, len(words), words_per_chunk):
                    chunks.extend(self.chunker(' '.join(words[j:j + words_per_chunk])))
            json_chunks = [{
                "raw_chunk": chunk.text,
            } for chunk in chunks]
            output_dir = os.path.join(os.path.dirname(text_paths[i]), "extract")
            os.makedirs(output_dir, exist_ok=True)
            file_name = os.path.splitext(os.path.basename(text_paths[i]))[0] + '_chunk.json'
            output_path = os.path.join(output_dir, file_name)
            output_paths.append(output_path)
            with open(output_path, "w", encoding="utf-8") as f:
                json.dump(json_chunks, f, ensure_ascii=False, indent=4)
            total_chunks += len(chunks)
            self.logger.info(
                f"Successfully split {text_paths[i]} into {len(chunks)} chunks. Saved to {output_path}")
        # (Debug `print(output_paths)` banners removed.)
        dataframe[self.output_key] = output_paths
        output_file = storage.write(dataframe)
        # BUG FIX: the original logged `len(chunks)` here, which is only the
        # chunk count of the LAST processed file; report the true total.
        self.logger.info(
            f"Successfully split text into {total_chunks} chunks. Saved to {output_file}")
        return [output_key]
from dataflow.prompts.text2qa import Text2MultiHopQAGeneratorPrompt
import pandas as pd
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow import get_logger
from dataflow.utils.storage import DataFlowStorage
from dataflow.core import OperatorABC
from dataflow.core import LLMServingABC
import random
from typing import Any, Dict, List, Optional, Sequence
import json
from tqdm import tqdm
import re
from dataflow.core.prompt import prompt_restrict, DIYPromptABC
from typing import Union
@prompt_restrict(
    Text2MultiHopQAGeneratorPrompt
)
@OPERATOR_REGISTRY.register()
class KBCMultiHopQAGeneratorBatch(OperatorABC):
    r"""A processor for generating multi-hop question-answer pairs from user
    data.

    This class handles the processing of text data to generate multi-hop
    question-answer pairs using either an AI model or rule-based approaches.
    It manages the entire pipeline from text preprocessing to dataset curation.
    """
    def __init__(self,
                 llm_serving: LLMServingABC,
                 seed: int = 0,
                 lang="en",
                 prompt_template: Union[Text2MultiHopQAGeneratorPrompt, DIYPromptABC] = None
                 ):
        r"""Initialize the processor.

        Args:
            llm_serving: Serving backend used to generate QA pairs.
            seed (int, optional): Seed for the internal random generator.
            lang (str, optional): Language of the input text ("en" or "zh").
            prompt_template: Optional custom prompt; defaults to
                Text2MultiHopQAGeneratorPrompt.
        """
        self.rng = random.Random(seed)
        self.llm_serving = llm_serving
        self.lang = lang
        self.logger = get_logger()
        if prompt_template:
            self.prompt_template = prompt_template
        else:
            self.prompt_template = Text2MultiHopQAGeneratorPrompt()
    @staticmethod
    def get_desc(lang: str = "zh") -> tuple:
        """Returns a description of the processor's functionality.

        Args:
            lang (str, optional): Language for description ('zh' or 'en').
        Returns:
            tuple: Description strings in specified language, including format example
        """
        if lang == "zh":
            return (
                "MultiHopQAGenerator 是多跳问答对生成处理器,支持从文本中自动生成需要多步推理的问题与答案。",
                "处理流程包括:文本预处理、信息抽取、问题生成与回答生成,支持自定义语言模型后端和参数。",
                "输出格式如下:",
                "输入:\n"
                "text: <原始上下文文本>",
                "输出:\n"
                "{\n"
                "  \"text\": <处理后的文本字符串>,\n"
                "  \"qa_pairs\": [\n"
                "    {\n"
                "      \"question\": <字符串:生成的问题>,\n"
                "      \"reasoning_steps\": [\n"
                "        {\"step\": <推理过程的步骤 1>},\n"
                "        {\"step\": <步骤 2>} ...\n"
                "      ],\n"
                "      \"answer\": <字符串:最终答案>,\n"
                "      \"supporting_facts\": [<支持该答案的事实 1>, <事实 2>, ...],\n"
                "      \"type\": <可选:问题类型,如“生物学”、“历史”等>\n"
                "    },\n"
                "    ...\n"
                "  ],\n"
                "  \"metadata\": {\n"
                "    \"source\": <数据来源>,\n"
                "    \"timestamp\": <时间戳字符串>,\n"
                "    \"complexity\": <整数:问题复杂度标记>\n"
                "  }\n"
                "}"
            )
        else:
            return (
                "MultiHopQAGenerator is a processor for generating multi-hop question-answer pairs from raw text.",
                "It includes preprocessing, information extraction, and reasoning-based QA generation, with configurable LLM backends.",
                "Expected output format:",
                "Input:\n"
                "text: <raw input context>",
                "Output:\n"
                "{\n"
                "  \"text\": <processed input text>,\n"
                "  \"qa_pairs\": [\n"
                "    {\n"
                "      \"question\": <string: generated question>,\n"
                "      \"reasoning_steps\": [\n"
                "        {\"step\": <inference step 1>},\n"
                "        {\"step\": <inference step 2>} ...\n"
                "      ],\n"
                "      \"answer\": <string: final answer>,\n"
                "      \"supporting_facts\": [<fact 1>, <fact 2>, ...],\n"
                "      \"type\": <optional string: QA category>\n"
                "    },\n"
                "    ...\n"
                "  ],\n"
                "  \"metadata\": {\n"
                "    \"source\": <source string>,\n"
                "    \"timestamp\": <timestamp string>,\n"
                "    \"complexity\": <integer: reasoning complexity>\n"
                "  }\n"
                "}"
            )
    def process_text(
        self, text: str, source: str = "user_input"
    ) -> List[Dict[str, Any]]:
        r"""Process a single text to generate multi-hop QA pairs.

        Args:
            text (str): The input text to process.
            source (str, optional): Source identifier for the text.
                (default: :obj:`"user_input"`)
        Returns:
            List[Dict[str, Any]]: List of processed examples with QA pairs and
                metadata.
        """
        # Convert text to standard format
        raw_data = [
            {
                'text': text,
                'source': source,
            }
        ]
        # Construct examples
        constructor = ExampleConstructor(
            lang=self.lang,
            llm_serving=self.llm_serving,
            prompt_template=self.prompt_template
        )
        examples = constructor.construct_examples(raw_data)
        return examples
    def process_batch(
        self, texts: List[str], sources: Optional[List[str]] = None
    ) -> List[Dict[str, Any]]:
        r"""Process multiple texts in batch to generate multi-hop QA pairs.

        Args:
            texts (List[str]): List of input texts to process.
            sources (Optional[List[str]], optional): List of source
                identifiers. (default: :obj:`None`)
        Returns:
            List[Dict[str, Any]]: List of processed examples with QA pairs and
                metadata.
        Raises:
            ValueError: If length of sources doesn't match length of texts.
        """
        if sources is None:
            sources = ["default_source"] * len(texts)
        elif len(sources) != len(texts):
            raise ValueError("Length of sources must match length of texts")
        raw_data = [
            {
                'text': text,
                'source': source,
            }
            for text, source in zip(texts, sources)
        ]
        # Construct examples.
        # BUG FIX: this constructor previously omitted prompt_template, so any
        # custom prompt passed to __init__ was silently ignored on the batch
        # path (the one run() uses). Pass it through, matching process_text.
        constructor = ExampleConstructor(
            lang=self.lang,
            llm_serving=self.llm_serving,
            prompt_template=self.prompt_template,
        )
        examples = constructor.construct_examples(raw_data)
        return examples
    def _validate_dataframe(self, dataframe: pd.DataFrame):
        """Require the `input_key` column and refuse to overwrite `output_key`."""
        required_keys = [self.input_key]
        forbidden_keys = [self.output_key]
        missing = [k for k in required_keys if k not in dataframe.columns]
        conflict = [k for k in forbidden_keys if k in dataframe.columns]
        if missing:
            raise ValueError(f"Missing required column(s): {missing}")
        if conflict:
            raise ValueError(
                f"The following column(s) already exist and would be overwritten: {conflict}")
    def run(
        self,
        storage: DataFlowStorage = None,
        input_key: str = 'chunk_path',
        output_key: str = 'enhanced_chunk_path',
    ):
        """Generate multi-hop QA pairs for every chunk file listed in `input_key`.

        Each chunk file (.json list or .jsonl) is expected to contain records
        with a "cleaned_chunk" field; generated QA data is written back into
        the same file under a new "qa_pairs" field (overwrite in place).
        """
        self.input_key, self.output_key = input_key, output_key
        dataframe = storage.read("dataframe")
        self._validate_dataframe(dataframe)
        chunk_paths = dataframe[self.input_key].tolist()
        for chunk_path in chunk_paths:
            if not chunk_path:
                continue
            path_str = str(chunk_path)
            if path_str.endswith(".json"):
                with open(chunk_path, "r", encoding="utf-8") as f:
                    data = json.load(f)
            elif path_str.endswith(".jsonl"):
                with open(chunk_path, "r", encoding="utf-8") as f:
                    data = [json.loads(line) for line in f]
            else:
                # BUG FIX: was a bare print(); route through the logger.
                self.logger.warning(f"Unsupported file format: {chunk_path}")
                continue
            texts = [item["cleaned_chunk"] for item in data]
            # Generate QA pairs for all chunks of this file in one batch.
            qa_pairs_batch = self.process_batch(texts)
            # Attach the generated QA data to the original records.
            for item, qa_pairs in zip(data, qa_pairs_batch):
                item["qa_pairs"] = qa_pairs
            # Overwrite the source file with the enriched records.
            with open(chunk_path, "w", encoding="utf-8") as f:
                if path_str.endswith(".json"):
                    json.dump(data, f, ensure_ascii=False, indent=4)
                else:  # jsonl
                    for item in data:
                        f.write(json.dumps(item, ensure_ascii=False) + "\n")
            # BUG FIX: the original logged len(qa_pairs), i.e. the key count of
            # the last example dict (always 2); report the number of chunks.
            self.logger.info(f"constructed multihop QA for {len(qa_pairs_batch)} chunks in {chunk_path}")
        dataframe[self.output_key] = chunk_paths
        output_file = storage.write(dataframe)
        self.logger.info(f"Results saved to {output_file}")
        return [output_key]
class ExampleConstructor:
r"""Constructs training examples from raw text data.
This class handles the construction of training examples by preprocessing
text, extracting information pairs, and generating question-answer pairs.
"""
def __init__(
self,
lang: str = "en",
llm_serving: LLMServingABC = None,
min_text_length: int = 100,
max_text_length: int = 200000,
prompt_template = None
):
r"""Initialize the ExampleConstructor.
Args:
config (ProcessorConfig): Configuration for example construction.
multi_hop_agent (Optional[MultiHopGeneratorAgent], optional):
Agent for generating multi-hop QA pairs. (default: :obj:`None`)
"""
self.lang = lang
self.llm_sering = llm_serving
self.logger = get_logger()
self.max_length = max_text_length
self.min_length = min_text_length
# self.prompt = Text2MultiHopQAGeneratorPrompt(lang=self.lang)
if prompt_template:
self.prompt_template = prompt_template
else:
self.prompt_template = Text2MultiHopQAGeneratorPrompt()
def construct_examples(
self, raw_data: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
r"""Construct training examples from raw data.
Args:
raw_data (List[Dict[str, Any]]): List of raw data dictionaries
containing text and metadata.
Returns:
List[Dict[str, Any]]: List of constructed examples with QA pairs
and metadata.
"""
self.logger.info("Starting to construct examples...")
examples = []
for data in tqdm(raw_data, desc="Constructing examples"):
# 1. Text preprocessing
processed_text = self._preprocess_text(data.get('text', ''))
if not processed_text:
example = {
# 'text': processed_text,
'qa_pairs': [],
'metadata': {
'source': data.get('source', 'unknown'),
'timestamp': data.get('timestamp', ''),
'complexity': 0,
},
}
examples.append(example)
continue
# 2. Generate key information pairs
info_pairs = self._extract_info_pairs(processed_text)
# 3. Construct question-answer pairs
if(info_pairs):
qa_pairs = self._generate_qa_pairs(info_pairs)
else:
qa_pairs = []
# 4. Add metadata
example = {
# 'text': processed_text,
'qa_pairs': qa_pairs,
'metadata': {
'source': data.get('source', 'unknown'),
'timestamp': data.get('timestamp', ''),
'complexity': self._calculate_complexity(qa_pairs) if qa_pairs else 0,
},
}
examples.append(example)
# self.logger.info(f"Successfully constructed {len(examples)} examples")
return examples
def _preprocess_text(self, text: str) -> str:
r"""Preprocess input text for example construction.
Args:
text (str): Input text to preprocess.
Returns:
str: Preprocessed text, or empty string if text fails quality
checks.
"""
if not isinstance(text, str):
return ''
# 1. Basic cleaning
text = text.strip()
# 2. Length check
if (
len(text) < self.min_length
or len(text) > self.max_length
):
self.logger.warning("text fail to pass length check.")
return ''
# 3. Quality check
if not self._check_text_quality(text):
self.logger.warning("text fail to pass quality check.")
return ''
return text
def _calculate_special_char_ratio(self, text):
# 中文字符的Unicode范围(基本汉字+扩展)
chinese_ranges = [
(0x4E00, 0x9FFF), # 基本汉字
(0x3400, 0x4DBF), # 扩展A
(0x20000, 0x2A6DF), # 扩展B
(0x2A700, 0x2B73F), # 扩展C
(0x2B740, 0x2B81F), # 扩展D
(0x2B820, 0x2CEAF) # 扩展E
]
special_count = 0
for c in text:
# 检查是否为中文、字母数字或空格
is_chinese = any(start <= ord(c) <= end for start,
end in chinese_ranges)
if not (c.isalnum() or c.isspace() or is_chinese):
special_count += 1
return special_count / len(text) if text else 0
def _check_text_quality(self, text: str) -> bool:
r"""Check the quality of input text.
Args:
text (str): Text to check quality for.
Returns:
bool: True if text passes quality checks, False otherwise.
"""
# 1. Basic quality check
if (self.lang == "en" and text.count('.') < 2): # Must have at least 2 sentences
return False
elif (self.lang in ["zh", "ch"] and text.count("。") < 2):
return False
# 2. Special character ratio check
special_char_ratio = self._calculate_special_char_ratio(text)
if special_char_ratio > 0.3: # No more than 30% special characters
return False
return True
def _extract_info_pairs(self, text: str) -> List[Dict[str, Sequence[str]]]:
r"""Extract information pairs and relationships from text.
Args:
text (str): Input text to extract information from.
Returns:
List[Dict[str, Sequence[str]]]: List of dictionaries containing
premise, intermediate, conclusion, and related contexts.
"""
# Split into sentences
if (self.lang == "en"):
sentences = [s.strip() for s in text.split('.') if s.strip()]
else:
sentences = [s.strip() for s in text.split('。') if s.strip()]
info_pairs = []
# Extract combinations of multiple related sentences
for i in range(len(sentences) - 2):
if len(sentences[i]) > 10 and len(sentences[i + 1]) > 10:
info_pairs.append(
{
'premise': sentences[i],
'intermediate': sentences[i + 1],
'conclusion': sentences[i + 2]
if i + 2 < len(sentences)
else '',
'related_contexts': [
s
for j, s in enumerate(sentences)
if j != i and j != i + 1 and len(s) > 10
][:2],
# Limit to 2 additional related contexts
}
)
return info_pairs
def _generate_qa_pairs(
self, info_pairs: List[Dict[str, Sequence[str]]]
) -> List[Dict[str, str]]:
r"""Generate multi-hop question-answer pairs from information pairs.
Args:
info_pairs (List[Dict[str, Sequence[str]]]): List of information
pairs extracted from text.
Returns:
List[Dict[str, str]]: List of generated QA pairs.
"""
user_inputs = []
for pair in info_pairs:
# 1. Generate multi-hop question-answer pair using AI
# Construct full context
context = (
f"{pair['premise']}. {pair['intermediate']}."
f" {pair['conclusion']}"
)
user_inputs.append(
self.prompt_template.build_prompt(context))
sys_prompt = self.prompt_template.build_system_prompt()
responses = self.llm_sering.generate_from_input(
user_inputs=user_inputs, system_prompt=sys_prompt)
qa_pairs = self._extract_qa_pairs(responses)
return qa_pairs
def _extract_qa_pairs(self, responses: List[str]) -> List[Dict[str, Any]]:
"""
从原始响应中精确提取符合结构的QA对
自动跳过非法JSON和干扰文本
"""
qa_pairs = []
for response in responses:
# self.logger.info(f"generated qa: {response}")
# 方法1:尝试直接解析整个响应为JSON
try:
qa_pair = json.loads(response)
if isinstance(qa_pair, dict) and "question" in qa_pair:
qa_pairs.append(qa_pair)
continue
elif isinstance(qa_pair, list):
for item in qa_pair:
if isinstance(item, dict) and "question" in item:
qa_pairs.append(item)
continue
except json.JSONDecodeError:
pass
# 方法2:使用正则表达式查找所有JSON对象
try:
# 查找所有以 { 开始的JSON对象
json_pattern = r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}'
# 更精确的模式,匹配完整的JSON对象
brace_count = 0
start_pos = -1
json_objects = []
for i, char in enumerate(response):
if char == '{':
if brace_count == 0:
start_pos = i
brace_count += 1
elif char == '}':
brace_count -= 1
if brace_count == 0 and start_pos != -1:
json_str = response[start_pos:i+1]
json_objects.append(json_str)
start_pos = -1
# 尝试解析找到的每个JSON字符串
for json_str in json_objects:
try:
qa_pair = json.loads(json_str)
if (isinstance(qa_pair, dict) and
"question" in qa_pair and
"reasoning_steps" in qa_pair and
"answer" in qa_pair and
"supporting_facts" in qa_pair and
"type" in qa_pair):
qa_pairs.append(qa_pair)
# self.logger.info(
# f"Successfully extracted QA pair: {qa_pair['question']}")
except json.JSONDecodeError as e:
self.logger.debug(
f"Failed to parse JSON object: {json_str[:100]}... Error: {e}")
continue
# 对qa_pairs中重复的question进行去重
if qa_pairs:
seen_questions = set()
unique_qa_pairs = []
for qa_pair in qa_pairs:
question = qa_pair.get("question", "").strip().lower()
if question and question not in seen_questions:
seen_questions.add(question)
unique_qa_pairs.append(qa_pair)
self.logger.debug(
f"Added unique question: {qa_pair['question']}")
else:
self.logger.debug(
f"Skipped duplicate question: {qa_pair.get('question', 'N/A')}")
qa_pairs = unique_qa_pairs
# self.logger.info(
# f"After deduplication: {len(qa_pairs)} unique QA pairs")
# 如果没有找到有效的JSON对象,记录警告
if not json_objects:
self.logger.warning(
"No JSON objects found in model response.")
except Exception as e:
self.logger.warning(
f"Failed to parse QA information from model response. Error: {e}")
return qa_pairs
def _calculate_complexity(self, qa_pairs: List[Dict[str, Any]]) -> float:
r"""Calculate the complexity score for a set of QA pairs.
Args:
qa_pairs (List[Dict[str, Any]]): List of QA pairs to calculate
complexity for.
Returns:
float: Complexity score between 0.0 and 1.0.
"""
if not qa_pairs:
return 0.0
# Calculate complexity based on multiple factors
complexities = []
for qa in qa_pairs:
# 1. Number of reasoning steps
reasoning_steps_count = len(qa.get('reasoning_steps', []))
# 2. Number of supporting facts
supporting_facts_count = len(qa.get('supporting_facts', []))
# 3. Question length
question_length = len(qa.get('question', '').split())
# 4. Answer length
answer_length = len(qa.get('answer', '').split())
# Calculate complexity of a single QA pair
qa_complexity = (
min(reasoning_steps_count / 3, 1.0)
* 0.4 # Weight for reasoning steps
+ min(supporting_facts_count / 3, 1.0)
* 0.3 # Weight for supporting facts
+ min(question_length / 20, 1.0)
* 0.15 # Weight for question length
+ min(answer_length / 50, 1.0) * 0.15
# Weight for answer length
)
complexities.append(qa_complexity)
return sum(complexities) / len(complexities)
from dataflow.prompts.kbcleaning import KnowledgeCleanerPrompt
import pandas as pd
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow import get_logger
from dataflow.utils.storage import DataFlowStorage
from dataflow.core import OperatorABC
from dataflow.core import LLMServingABC
from dataflow.core.prompt import prompt_restrict, DIYPromptABC
from typing import Union
import re
@prompt_restrict(
    KnowledgeCleanerPrompt
)
@OPERATOR_REGISTRY.register()
class KBCTextCleaner(OperatorABC):
    '''
    KBCTextCleaner cleans raw knowledge chunks for RAG so they become more
    accurate, reliable and readable. Each chunk in the input column is sent
    to the LLM with a cleaning prompt; the cleaned text is read back from the
    <cleaned_start>...<cleaned_end> markers of the response.
    '''

    def __init__(self, llm_serving: LLMServingABC, lang="en", prompt_template: Union[KnowledgeCleanerPrompt, DIYPromptABC] = None):
        """
        Args:
            llm_serving: LLM serving backend used to run the cleaning prompts.
            lang: Prompt language ("en" or "zh").
            prompt_template: Optional custom prompt template; defaults to a
                KnowledgeCleanerPrompt built for ``lang``.
        """
        self.logger = get_logger()
        self.prompts = KnowledgeCleanerPrompt(lang=lang)
        self.llm_serving = llm_serving
        if prompt_template:
            self.prompt_template = prompt_template
        else:
            # Bug fix: the default template previously ignored `lang`
            # (always English), unlike the batch variant of this operator.
            self.prompt_template = KnowledgeCleanerPrompt(lang=lang)

    @staticmethod
    def get_desc(lang: str = "zh"):
        """Return the operator description in the requested language."""
        if lang == "zh":
            return (
                "知识清洗算子:对原始知识内容进行标准化处理,包括HTML标签清理、特殊字符规范化、"
                "链接处理和结构优化,提升RAG知识库的质量。主要功能:\n"
                "1. 移除冗余HTML标签但保留语义化标签\n"
                "2. 标准化引号/破折号等特殊字符\n"
                "3. 处理超链接同时保留文本\n"
                "4. 保持原始段落结构和代码缩进\n"
                "5. 确保事实性内容零修改\n"
                "\n输入格式示例:\n"
                "<div class=\"container\">\n"
                "  <h1>标题文本</h1>\n"
                "  <p>正文段落,包括特殊符号,例如“弯引号”、–破折号等</p>\n"
                "  <img src=\"example.jpg\" alt=\"示意图\">\n"
                "  <a href=\"...\">链接文本</a>\n"
                "  <pre><code>代码片段</code></pre>\n"
                "  ...\n"
                "</div>\n"
                "\n输出格式示例:\n"
                "标题文本\n\n"
                "正文段落,包括特殊符号,例如\"直引号\"、-破折号等\n\n"
                "[Image: 示例图 example.jpg]\n\n"
                "链接文本\n\n"
                "<code>代码片段</code>\n\n"
                "[结构保持,语义保留,敏感信息脱敏处理(如手机号、保密标记等)]"
            )
        elif lang == "en":
            return (
                "Knowledge Cleaning Operator: Standardizes raw HTML/text content for RAG quality improvement. Key functions:\n"
                "1. Removes redundant HTML tags while preserving semantic tags\n"
                "2. Normalizes special characters (e.g., curly quotes, dashes)\n"
                "3. Processes hyperlinks and retains their text\n"
                "4. Preserves paragraph structure and code indentation\n"
                "5. Ensures factual content remains unchanged\n"
                "\nExample Input Format:\n"
                "<div class=\"container\">\n"
                "  <h1>Title Text</h1>\n"
                "  <p>Paragraph with “curly quotes” and – dashes</p>\n"
                "  <img src=\"example.jpg\" alt=\"Diagram\">\n"
                "  <a href=\"...\">Link text</a>\n"
                "  <pre><code>Code block</code></pre>\n"
                "  ...\n"
                "</div>\n"
                "\nExample Output Format:\n"
                "Title Text\n\n"
                "Paragraph with \"straight quotes\" and - dashes\n\n"
                "[Image: Diagram example.jpg]\n\n"
                "Link text\n\n"
                "<code>Code block</code>\n\n"
                "[Structure retained, semantics preserved, sensitive info masked (e.g., phone numbers, confidential tags)]"
            )
        else:
            return "Knowledge cleaning operator for RAG content standardization. Set lang='zh' or 'en' for examples."

    def _validate_dataframe(self, dataframe: pd.DataFrame):
        """Ensure the input column exists and the output column does not."""
        required_keys = [self.input_key]
        forbidden_keys = [self.output_key]
        missing = [k for k in required_keys if k not in dataframe.columns]
        conflict = [k for k in forbidden_keys if k in dataframe.columns]
        if missing:
            raise ValueError(f"Missing required column(s): {missing}")
        if conflict:
            raise ValueError(f"The following column(s) already exist and would be overwritten: {conflict}")

    def _reformat_prompt(self, dataframe):
        """
        Build one cleaning prompt per raw chunk in the input column.
        """
        raw_contents = dataframe[self.input_key].tolist()
        inputs = [self.prompt_template.build_prompt(raw_content) for raw_content in raw_contents]
        return inputs

    def run(
        self,
        storage: DataFlowStorage,
        input_key: str = "raw_chunk",
        output_key: str = "cleaned_chunk"
    ):
        '''
        Runs the knowledge cleaning process, reading from the input key and
        saving results to the output key.

        Args:
            storage: DataFlow storage holding the dataframe to clean.
            input_key: Column containing the raw chunks.
            output_key: Column to write cleaned chunks into (must not exist).

        Returns:
            list: The single output column name.
        '''
        self.input_key, self.output_key = input_key, output_key
        dataframe = storage.read("dataframe")
        self._validate_dataframe(dataframe)
        formatted_prompts = self._reformat_prompt(dataframe)
        cleaned = self.llm_serving.generate_from_input(formatted_prompts, "")
        # Keep only the text between <cleaned_start> and <cleaned_end>;
        # fall back to the whole (stripped) response when markers are absent.
        cleaned_extracted = [
            str(text).split('<cleaned_start>')[1].split('<cleaned_end>')[0].strip()
            if '<cleaned_start>' in str(text) and '<cleaned_end>' in str(text)
            else str(text).strip()
            for text in cleaned
        ]
        dataframe[self.output_key] = cleaned_extracted
        output_file = storage.write(dataframe)
        self.logger.info(f"Results saved to {output_file}")
        return [output_key]
\ No newline at end of file
from dataflow.prompts.kbcleaning import KnowledgeCleanerPrompt
import pandas as pd
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow import get_logger
import json
from dataflow.utils.storage import DataFlowStorage
from dataflow.core import OperatorABC
from dataflow.core import LLMServingABC
from dataflow.core.prompt import prompt_restrict, DIYPromptABC
from typing import Union
import re
@prompt_restrict(
    KnowledgeCleanerPrompt
)
@OPERATOR_REGISTRY.register()
class KBCTextCleanerBatch(OperatorABC):
    '''
    KBCTextCleanerBatch cleans knowledge for RAG in batch mode: each row of
    the input dataframe points to a chunk file (.json/.jsonl) whose contents
    are cleaned by the LLM and written back to the same file in place.
    '''

    def __init__(self, llm_serving: LLMServingABC, lang="en", prompt_template: Union[KnowledgeCleanerPrompt, DIYPromptABC] = None):
        """
        Args:
            llm_serving: LLM serving backend used for cleaning.
            lang: Prompt language ("en" or "zh").
            prompt_template: Optional custom prompt template; defaults to a
                KnowledgeCleanerPrompt built for ``lang``.
        """
        self.logger = get_logger()
        self.prompts = KnowledgeCleanerPrompt(lang=lang)
        self.llm_serving = llm_serving
        if prompt_template:
            self.prompt_template = prompt_template
        else:
            self.prompt_template = KnowledgeCleanerPrompt(lang=lang)

    @staticmethod
    def get_desc(lang: str = "zh"):
        """Return the operator description in the requested language."""
        if lang == "zh":
            return (
                "知识清洗算子:对原始知识内容进行标准化处理,包括HTML标签清理、特殊字符规范化、"
                "链接处理和结构优化,提升RAG知识库的质量。主要功能:\n"
                "1. 移除冗余HTML标签但保留语义化标签\n"
                "2. 标准化引号/破折号等特殊字符\n"
                "3. 处理超链接同时保留文本\n"
                "4. 保持原始段落结构和代码缩进\n"
                "5. 确保事实性内容零修改"
            )
        elif lang == "en":
            return (
                "Knowledge Cleaning Operator: Standardizes raw content for RAG by:\n"
                "1. Removing redundant HTML tags while preserving semantic markup\n"
                "2. Normalizing special characters (quotes/dashes)\n"
                "3. Processing hyperlinks with text preservation\n"
                "4. Maintaining original paragraph structure and code indentation\n"
                "5. Guaranteeing zero modification of factual content"
            )
        else:
            return "Knowledge cleaning operator for RAG content standardization"

    def _validate_dataframe(self, dataframe: pd.DataFrame):
        """Ensure the input column exists and the output column does not."""
        required_keys = [self.input_key]
        forbidden_keys = [self.output_key]
        missing = [k for k in required_keys if k not in dataframe.columns]
        conflict = [k for k in forbidden_keys if k in dataframe.columns]
        if missing:
            raise ValueError(f"Missing required column(s): {missing}")
        if conflict:
            raise ValueError(
                f"The following column(s) already exist and would be overwritten: {conflict}")

    def _reformat_prompt(self, dataframe):
        """
        Build one cleaning prompt per raw chunk in the dataframe.
        """
        raw_contents = dataframe[self.input_key].tolist()
        inputs = [self.prompt_template.build_prompt(
            raw_content) for raw_content in raw_contents]
        return inputs

    def _reformat_prompt_from_path(self, chunk_path: str) -> tuple:
        """
        Build cleaning prompts from a chunk file (JSON or JSONL).

        Args:
            chunk_path (str): Path to the .json or .jsonl file containing raw chunks.

        Returns:
            tuple: (raw_contents, prompts) — the raw chunk texts and the
            corresponding formatted prompt strings.
        """
        if chunk_path.endswith(".json"):
            dataframe = pd.read_json(chunk_path)
        elif chunk_path.endswith(".jsonl"):
            dataframe = pd.read_json(chunk_path, lines=True)
        else:
            raise ValueError(
                "Unsupported file format. Only .json and .jsonl are supported.")
        if "raw_chunk" not in dataframe.columns:
            raise KeyError("'raw_chunk' field not found in the input file.")
        raw_contents = dataframe["raw_chunk"].tolist()
        # Consistency fix: honor a user-supplied prompt_template (possibly a
        # DIY prompt) instead of always using the default self.prompts.
        inputs = [self.prompt_template.build_prompt(
            raw_content) for raw_content in raw_contents]
        return raw_contents, inputs

    def run(
        self,
        storage: DataFlowStorage,
        input_key: str = "chunk_path",
        output_key: str = "cleaned_chunk_path"
    ):
        '''
        Runs the knowledge cleaning process: for every chunk file referenced
        in the input column, cleans its chunks with the LLM and rewrites the
        file as a list of {raw_chunk, cleaned_chunk} records.

        Args:
            storage: DataFlow storage holding the dataframe of chunk paths.
            input_key: Column containing chunk file paths.
            output_key: Column to write the (identical) paths into.

        Returns:
            list: The single output column name.
        '''
        self.input_key, self.output_key = input_key, output_key
        dataframe = storage.read("dataframe")
        self._validate_dataframe(dataframe)
        chunk_paths = dataframe[self.input_key].tolist()
        for chunk_path in chunk_paths:
            if (chunk_path):
                raw_chunks, formatted_prompts = self._reformat_prompt_from_path(chunk_path)
                cleaned = self.llm_serving.generate_from_input(formatted_prompts, "")
                # Keep only the text between <cleaned_start> and <cleaned_end>;
                # fall back to the whole (stripped) response otherwise.
                # Bug fix: coerce `text` to str before splitting, matching the
                # guard condition (previously crashed on non-str responses).
                cleaned_extracted = [
                    str(text).split('<cleaned_start>')[1].split('<cleaned_end>')[0].strip()
                    if '<cleaned_start>' in str(text) and '<cleaned_end>' in str(text)
                    else str(text).strip()
                    for text in cleaned
                ]
                json_items = [{
                    "raw_chunk": raw_chunk,
                    "cleaned_chunk": cleaned_chunk
                } for raw_chunk, cleaned_chunk in zip(raw_chunks, cleaned_extracted)]
                with open(chunk_path, "w", encoding="utf-8") as f:
                    json.dump(json_items, f, ensure_ascii=False, indent=4)
                self.logger.info(f"Successfully cleaned contents in {chunk_path}")
        dataframe[self.output_key] = chunk_paths
        output_file = storage.write(dataframe)
        self.logger.info(f"Results saved to {output_file}")
        return [output_key]
import sys
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow import get_logger
from dataflow.core import OperatorABC
import os
from pathlib import Path
import json
import shutil
import fitz # pip install pymupdf
from dataflow.prompts.kbcleaning import MathbookQuestionExtractPrompt
import re
from openai import OpenAI
import base64
from typing import Literal, Union
from dataflow.core import LLMServingABC
from dataflow.serving import APIVLMServing_openai
from dataflow.core.prompt import DIYPromptABC
from dataflow.utils.storage import DataFlowStorage
@OPERATOR_REGISTRY.register()
class MathBookQuestionExtract(OperatorABC):
    """Extract questions and associated figures from a math textbook PDF.

    Pipeline: render PDF pages to images, run MinerU to extract content and
    cropped figures, reorganize the figures into a flat folder, then ask a
    vision-language model to transcribe questions from page/figure images;
    results are saved as JSON and Markdown.
    """

    def __init__(self,
                 llm_serving: APIVLMServing_openai,
                 # NOTE(review): the default prompt instance is created once at
                 # import time and shared across operators — verify the prompt
                 # class is stateless.
                 prompt_template: Union[MathbookQuestionExtractPrompt, DIYPromptABC] = MathbookQuestionExtractPrompt(),
                 mineru_backend: str = "vlm-vllm-engine",
                 dpi: int = 300,
                 key_name_of_api_key: str = "DF_API_KEY",
                 model_name: str = "o4-mini",
                 max_workers: int = 20
                 ):
        """
        Args:
            llm_serving: VLM serving backend used to analyze page images.
            prompt_template: Prompt used as the system prompt for the VLM.
            mineru_backend: MinerU backend name passed to the CLI.
            dpi: Resolution used when rendering PDF pages to images.
            key_name_of_api_key: Environment variable holding the API key.
            model_name: Model identifier passed to the serving backend.
            max_workers: Parallelism hint (currently unused, see below).
        """
        self.logger = get_logger()
        self.llm_serving = llm_serving
        self.prompt_template = prompt_template
        self.mineru_backend = mineru_backend
        self.dpi = dpi
        self.key_name_of_api_key = key_name_of_api_key
        self.model_name = model_name
        self.max_workers = max_workers  # NOTE: not used by the original logic, but kept in __init__ as requested

    @staticmethod
    def get_desc(lang: str = "zh"):
        """Return the operator description in the requested language."""
        if lang == "zh":
            return (
                "该算子用于从数学教材PDF中提取问题和相关图片内容。它将PDF转换为图片,使用MinerU进行内容提取,"
                "然后组织图片并使用大语言模型分析内容,最终生成包含问题和图片的JSON和Markdown文件。\n"
                "输入参数:\n"
                "- llm_serving:VLM服务对象,需实现APIVLMServing_openai接口\n"
                "- pdf_file_path:PDF文件路径\n"
                "- output_file_name:输出文件名\n"
                "- output_folder:输出文件夹路径\n"
                "- MinerU_Backend:MinerU后端类型,默认为'vlm-sglang-engine'\n"
                "- dpi:PDF转图片的分辨率,默认为300\n"
                "- api_url:API服务URL\n"
                "- key_name_of_api_key:API密钥的环境变量名\n"
                "- model_name:使用的模型名称,默认为'o4-mini'\n"
                "- max_workers:最大并行工作线程数,默认为20\n"
                "输出参数:\n"
                "- 返回布尔值表示处理是否成功\n"
                "- 在指定文件夹生成JSON和Markdown格式的提取结果"
            )
        elif lang == "en":
            return (
                "This operator extracts questions and related images from mathematics textbook PDFs. It converts the PDF to images, "
                "uses MinerU for content extraction, organizes the images, and analyzes the content using a large vision-language model, "
                "ultimately generating JSON and Markdown files containing questions and images.\n"
                "Input Parameters:\n"
                "- llm_serving: VLM serving object implementing APIVLMServing_openai interface\n"
                "- pdf_file_path: Path to the PDF file\n"
                "- output_file_name: Name for the output files\n"
                "- output_folder: Path to the output folder\n"
                "- MinerU_Backend: MinerU backend type, default is 'vlm-sglang-engine'\n"
                "- dpi: Resolution for PDF to image conversion, default is 300\n"
                "- api_url: API service URL\n"
                "- key_name_of_api_key: Environment variable name for API key\n"
                "- model_name: Model name to use, default is 'o4-mini'\n"
                "- max_workers: Maximum number of parallel workers, default is 20\n\n"
                "Output Parameters:\n"
                "- Returns boolean indicating success of processing\n"
                "- Generates extraction results in JSON and Markdown formats in the specified folder"
            )
        else:
            return (
                "MathBookQuestionExtract processes mathematics textbook PDFs to extract questions and images using MinerU and VLM."
            )

    def mineru2_runner(self,
                       pdf_file_path:str,
                       output_folder:str,
                       # pipeline|vlm-transformers|vlm-vllm-engine|vlm-http-client
                       mineru_backend: Literal["pipeline", "vlm-transformers", "vlm-vllm-engine", "vlm-http-client"] = "pipeline"
                       ):
        """Run the MinerU CLI on a PDF and locate its output artifacts.

        Args:
            pdf_file_path: Path to the input PDF.
            output_folder: Directory MinerU writes its structured output into.
            mineru_backend: MinerU backend selector.

        Returns:
            tuple: (content-list JSON path, cropped-images folder path).

        Raises:
            RuntimeError: When the MinerU CLI exits with a non-zero status.
        """
        try:
            import mineru
        except ImportError:
            raise Exception(
                """
                MinerU is not installed in this environment yet.
                Please refer to https://github.com/opendatalab/mineru to install.
                Or you can just execute 'pip install mineru[pipeline]' and 'mineru-models-download' to fix this error.
                Please make sure you have GPU on your machine.
                """
            )
        os.environ['MINERU_MODEL_SOURCE'] = "local"   # Optional: load models from a local source
        # MinerU writes its output under a backend-family subfolder.
        MinerU_Version = {"pipeline": "auto", "vlm-transformers": "vlm", 'vlm-vllm-engine': 'vlm', 'vlm-http-client': 'vlm'}
        raw_file = Path(pdf_file_path)
        pdf_name = raw_file.stem
        intermediate_dir = output_folder
        try:
            # NOTE(review): paths are interpolated unquoted into a shell
            # command — paths containing spaces will break; consider quoting.
            return_code = os.system(
                f"mineru -p {raw_file} -o {intermediate_dir} -b {mineru_backend} --source local"
            )
            if return_code != 0:
                raise RuntimeError(f"MinerU execution failed with return code: {return_code}")
        except Exception as e:
            raise RuntimeError(f"Failed to process file with MinerU: {str(e)}")
        output_file = os.path.join(intermediate_dir, pdf_name, MinerU_Version[mineru_backend], f"{pdf_name}_content_list.json")
        output_pic_folder = os.path.join(intermediate_dir, pdf_name, MinerU_Version[mineru_backend], "images")
        self.logger.info(f"MinerU json file has been saved to {output_file}")
        return output_file, output_pic_folder

    def organize_pics(
        self,
        mineru_content_json_path: str,
        mineru_image_folder: str,
        output_file_path: str,
        output_pic_folder: str
    ):
        '''
        Helper that reorganizes the images cropped by MinerU into the final folder.

        Input:
            mineru_content_json_path: path of the content-list JSON produced by MinerU
            mineru_image_folder: folder holding the images cropped by MinerU
        Output:
            output_file_path: JSON record of the reorganized images (path + page
                index), consumed by the later image-processing steps
            output_pic_folder: final folder the renamed images are copied into
        '''
        global_counter = 0
        global_json_data = []
        # read mineru content json
        json_data = json.load(open(mineru_content_json_path, 'r'))
        # if output_pic_folder is not exist, create it
        if not os.path.exists(output_pic_folder):
            os.makedirs(output_pic_folder)
        for item in json_data:
            if item['type'] == 'image':
                # get the image name
                image_name = item['img_path'].split('/')[-1]
                # get the image path
                image_path = os.path.join(mineru_image_folder, image_name)
                page_idx = item['page_idx']
                # rename the image to a running counter so names are unique
                new_image_name = f"{global_counter}.jpg"
                new_image_path = os.path.join(output_pic_folder, new_image_name)
                shutil.copy(image_path, new_image_path)
                # add to global json data
                global_json_data.append({
                    "img_path": new_image_path,
                    "page_idx": page_idx,
                })
                global_counter += 1
        # write to json file
        with open(output_file_path, 'w') as f:
            json.dump(global_json_data, f, indent=4)

    def pdf2images(self, pdf_path: str, output_folder: str, dpi: int = 300):
        '''
        Helper that renders every page of a PDF to a JPEG image.

        Input:
            pdf_path: path of the PDF file
            output_folder: folder the page images are written into
            dpi: render resolution
        '''
        doc = fitz.open(pdf_path)
        # make output directory if it doesn't exist
        os.makedirs(output_folder, exist_ok=True)
        # convert each page to image named page_<index>.jpg
        for page_index in range(len(doc)):
            page = doc.load_page(page_index)
            pix = page.get_pixmap(dpi=dpi)
            pix.save(f"{output_folder}/page_{page_index}.jpg")
            self.logger.info(f"Converted page {page_index} to image")
        return True

    def encode_image_to_base64(self, image_path: str) -> str:
        """Read an image file and return its base64-encoded contents."""
        with open(image_path, "rb") as f:
            return base64.b64encode(f.read()).decode("utf-8")

    def process_input(self,
                      page_folder: str,
                      img_json_path: str
                      ):
        """Build per-window image/label lists for the VLM.

        Each window pairs page i with page i+1 plus any MinerU-cropped
        figures belonging to either page — presumably to give the model
        cross-page context for questions spanning a page break.

        Returns:
            tuple: (list of image-path lists, list of matching label lists).
        """
        # load every page_n.jpg inside page_folder
        page_list = [os.path.join(page_folder, f) for f in os.listdir(page_folder) if f.endswith(('.jpg'))]
        idx_list = [int(f.split("/")[-1].split(".")[0].split("_")[-1]) for f in page_list]
        max_page_idx = max(idx_list)
        # load img_json
        img_json = json.load(open(img_json_path, "r"))
        img_dict = {}
        for item in img_json:
            if item["page_idx"] not in img_dict:
                img_dict[item["page_idx"]] = []
            img_dict[item["page_idx"]].append(item["img_path"])
        full_input_image_list = []
        full_input_label_list = []
        for page_idx in range(max_page_idx):
            image_list = []
            label_list = []
            image_list.append(os.path.join(page_folder, f"page_{page_idx}.jpg"))
            label_list.append(f"page_{page_idx}")
            image_list.append(os.path.join(page_folder, f"page_{page_idx+1}.jpg"))
            label_list.append(f"page_{page_idx+1}")
            if page_idx in img_dict:
                image_list.extend(img_dict[page_idx])
                label_list.extend([img_path.split("/")[-1] for img_path in img_dict[page_idx]])
            if page_idx+1 in img_dict:
                image_list.extend(img_dict[page_idx+1])
                label_list.extend([img_path.split("/")[-1] for img_path in img_dict[page_idx+1]])
            full_input_image_list.append(image_list)
            full_input_label_list.append(label_list)
        return full_input_image_list,full_input_label_list

    def analyze_and_save(self,result_list,save_folder,img_folder,output_file_name):
        """Parse VLM responses and persist them as JSON and Markdown.

        Responses are split on the <SPACE> marker; embedded
        ``<image>NAME.jpg</image>`` tags are stripped for the JSON output and
        rewritten to Markdown image links (with the referenced images copied
        into ``save_folder/images``) for the Markdown output.
        """
        # (analyze_and_save method kept unchanged from the original logic)
        # make save_folder if not exist
        if not os.path.exists(save_folder):
            os.makedirs(save_folder)
        # make save_folder/images if not exist
        if not os.path.exists(os.path.join(save_folder, "images")):
            os.makedirs(os.path.join(save_folder, "images"))
        output_json = []
        output_markdown_text = ""
        for item in result_list:
            if not item:
                continue
            split_text = item.split("<SPACE>")
            for text in split_text:
                if not text:
                    continue
                # find every occurrence shaped like <image>index.jpg</image>,
                # strictly matching <image>*.jpg</image>
                pic_list = []
                pic_match = re.findall(r'<image>(.*?)\.jpg</image>', text)
                if pic_match:
                    for pic_name in pic_match:
                        # store the full path
                        pic_list.append(os.path.join(img_folder, f"{pic_name}.jpg"))
                    # JSON-style text: drop every <image>index.jpg</image> tag
                    json_text = re.sub(r'<image>(.*?)\.jpg</image>', '', text)
                    # Markdown-style text: replace <image>index.jpg</image>
                    # with ![](images/index.jpg)
                    markdown_text = text
                    for pic_name in pic_match:
                        # copy img_folder/pic_name.jpg to save_folder/images/pic_name.jpg
                        shutil.copy(os.path.join(img_folder, f"{pic_name}.jpg"), os.path.join(save_folder, "images", f"{pic_name}.jpg"))
                        markdown_text = markdown_text.replace(f"<image>{pic_name}.jpg</image>", f"![](images/{pic_name}.jpg)")
                else:
                    json_text = text
                    markdown_text = text
                    pic_list = []
                json_data = {
                    "text": json_text,
                    "pics": pic_list
                }
                output_json.append(json_data)
                output_markdown_text += markdown_text + "\n" + "---" + "\n"
        # save output_json to save_folder
        with open(os.path.join(save_folder, f"{output_file_name}.json"), "w") as f:
            json.dump(output_json, f, indent=4, ensure_ascii=False)
        # save output_markdown_text to save_folder
        with open(os.path.join(save_folder, f"{output_file_name}.md"), "w", encoding="utf-8") as f:
            f.write(output_markdown_text)
        return output_json,output_markdown_text

    def run(
        self,
        storage: DataFlowStorage,
        input_pdf_file_path: str,
        output_file_name: str,
        output_folder: str,
    ):
        """Execute the full PDF -> questions extraction pipeline.

        Args:
            storage: DataFlow storage (unused by this operator's logic).
            input_pdf_file_path: Path of the textbook PDF to process.
            output_file_name: Base name for the generated .json/.md files.
            output_folder: Folder all intermediate and final outputs go into.

        Returns:
            bool: True on completion.
        """
        # get the configuration parameters from self
        # NOTE(review): api_key is only validated here for early failure;
        # presumably the serving object reads the key itself — confirm.
        api_key = os.environ.get(self.key_name_of_api_key)
        if not api_key:
            raise ValueError(f"API key not found in environment variable {self.key_name_of_api_key}")
        # 1. convert pdf to images
        pdf2images_folder_name = os.path.join(output_folder, "pdfimages")
        self.pdf2images(input_pdf_file_path, pdf2images_folder_name, self.dpi)
        # 2. use mineru to extract content and pics
        json_content_file, pic_folder = self.mineru2_runner(input_pdf_file_path, output_folder, self.mineru_backend)
        # 3. organize_pics
        output_image_folder = os.path.join(output_folder, "organized_images")
        output_json_file = os.path.join(output_image_folder, "organized_info.json")
        self.organize_pics(json_content_file, pic_folder, output_json_file, output_image_folder)
        # 4. process input
        full_input_image_list, full_input_label_list = self.process_input(pdf2images_folder_name, output_json_file)
        # 5. init server and generate
        system_prompt = self.prompt_template.build_prompt()
        result_text_list = self.llm_serving.generate_from_input_multi_images(
            list_of_image_paths=full_input_image_list,
            list_of_image_labels=full_input_label_list,
            system_prompt=system_prompt,
            model=self.model_name,
            timeout=1800
        )
        # 6. save responses
        self.analyze_and_save(result_text_list, output_folder, output_image_folder, output_file_name)
        # 7. return
        return True
#!/usr/bin/env python3
"""QA Extractor - 提取QA对并转换为Alpaca格式"""
import json
from pathlib import Path
from typing import Optional, List
from dataflow.core import OperatorABC
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow.utils.storage import DataFlowStorage
from dataflow import get_logger
@OPERATOR_REGISTRY.register()
class QAExtractor(OperatorABC):
    """
    Extract question-answer pairs from the QA_pairs field and convert them
    to the Alpaca fine-tuning format.
    Input: QA_pairs (nested structure)
    Output: instruction, input, output (Alpaca format)
    """
    def __init__(
        self,
        qa_key: str = "QA_pairs",
        output_json_file: Optional[str] = None,
        instruction: str = "Please answer the following question based on the provided information."
    ):
        """
        Args:
            qa_key: Column/field name holding the nested QA pairs.
            output_json_file: Optional path; when set, the extracted pairs are
                also dumped to this JSON file.
            instruction: Fixed instruction text used for every Alpaca record.
        """
        self.logger = get_logger()
        self.qa_key = qa_key
        self.output_json_file = output_json_file
        self.instruction = instruction

    @staticmethod
    def get_desc(lang: str = "zh"):
        """Return the operator description in the requested language."""
        if lang == "zh":
            return (
                "QA对提取器 - 将嵌套的QA_pairs转换为Alpaca微调格式\n\n"
                "核心功能:\n"
                "从结构化的QA对数据中提取问答内容,自动整合推理步骤和支持事实,\n"
                "输出符合Stanford Alpaca标准的instruction-input-output格式。\n\n"
                "初始化参数:\n"
                "• qa_key: QA对的字段名 (默认: 'QA_pairs')\n"
                "• output_json_file: 输出JSON文件路径 (可选,不指定则只更新DataFrame)\n"
                "• instruction: 统一的指令前缀 (默认: 'Please answer the following question...')\n\n"
                "运行参数 (input_key):\n"
                "• None - 包含所有字段 (question + reasoning_steps + supporting_facts)\n"
                "• '' - 空字符串,不包含额外上下文\n"
                "• 'reasoning_steps' - 只包含推理步骤\n"
                "• 'question,reasoning_steps' - 逗号分隔多个字段\n"
                "• ['question', 'supporting_facts'] - 列表格式\n\n"
                "输出字段:\n"
                "• instruction: 问题指令\n"
                "• input: 上下文信息 (根据input_key动态拼接)\n"
                "• output: 答案\n\n"
                "适用场景: 知识库QA微调、领域问答模型训练"
            )
        else:  # English
            return (
                "QA Extractor - Convert nested QA_pairs to Alpaca fine-tuning format\n\n"
                "Core Function:\n"
                "Extract question-answer pairs from structured data, automatically integrate\n"
                "reasoning steps and supporting facts, output in Stanford Alpaca standard\n"
                "instruction-input-output format.\n\n"
                "Initialization Parameters:\n"
                "• qa_key: Field name for QA pairs (default: 'QA_pairs')\n"
                "• output_json_file: Output JSON path (optional, skip to only update DataFrame)\n"
                "• instruction: Unified instruction prefix (default: 'Please answer...')\n\n"
                "Runtime Parameters (input_key):\n"
                "• None - Include all fields (question + reasoning_steps + supporting_facts)\n"
                "• '' - Empty string, no additional context\n"
                "• 'reasoning_steps' - Only reasoning steps\n"
                "• 'question,reasoning_steps' - Comma-separated fields\n"
                "• ['question', 'supporting_facts'] - List format\n\n"
                "Output Fields:\n"
                "• instruction: Question as instruction\n"
                "• input: Context information (dynamically assembled by input_key)\n"
                "• output: Answer\n\n"
                "Use Cases: Knowledge base QA fine-tuning, domain-specific Q&A training"
            )

    def _parse_fields(self, input_key: Optional[str]) -> Optional[List[str]]:
        """Normalize input_key into a list of field names (None = all fields)."""
        if input_key is None:
            return None  # include every default field
        if isinstance(input_key, list):
            return input_key
        if isinstance(input_key, str):
            # Comma-separated string; an empty/blank string means "no context".
            return [f.strip() for f in input_key.split(',') if f.strip()] if input_key.strip() else []
        return None

    def _extract_qa(self, row, fields: Optional[List[str]] = None) -> List[dict]:
        """Extract Alpaca-format QA records from a single dataframe row."""
        qa_data = row.get(self.qa_key)
        if not qa_data:
            return []
        # supports both a bare list and a nested {'qa_pairs': [...]} structure
        qa_list = qa_data.get('qa_pairs', []) if isinstance(qa_data, dict) else qa_data
        if not isinstance(qa_list, list):
            return []
        results = []
        default_fields = ['question', 'reasoning_steps', 'supporting_facts']
        fields = fields if fields is not None else default_fields
        for qa in qa_list:
            if not isinstance(qa, dict):
                continue
            question = qa.get('question', '').strip()
            answer = qa.get('answer', '').strip()
            if not question or not answer:
                continue
            # assemble the Alpaca "input" field from the requested parts
            parts = []
            for field in fields:
                if field == 'question':
                    parts.append(f"Question: {question}")
                elif field == 'reasoning_steps' and qa.get('reasoning_steps'):
                    if parts:
                        parts.append("")
                    parts.append("Reasoning Process:")
                    for i, step in enumerate(qa['reasoning_steps'], 1):
                        # step entries may be dicts ({'step': ...}) or strings
                        text = step.get('step', step) if isinstance(step, dict) else str(step)
                        if text:
                            parts.append(f"{i}. {text}")
                elif field == 'supporting_facts' and qa.get('supporting_facts'):
                    if parts:
                        parts.append("")
                    parts.append("Supporting Information:")
                    for fact in qa['supporting_facts']:
                        text = fact.get('fact', fact) if isinstance(fact, dict) else str(fact)
                        if text:
                            parts.append(f"- {text}")
                elif field in qa and qa[field]:
                    # any other requested field is included verbatim
                    if parts:
                        parts.append("")
                    parts.append(f"{field}: {qa[field]}")
            results.append({
                'instruction': self.instruction,
                'input': "\n".join(parts),
                'output': answer
            })
        return results

    def _load_from_files(self, df):
        """Load QA data from chunk files referenced by one of the path columns."""
        import pandas as pd
        path_keys = ['enhanced_chunk_path', 'cleaned_chunk_path', 'chunk_path']
        path_col = next((k for k in path_keys if k in df.columns), None)
        if not path_col:
            raise ValueError(f"需要这些字段之一: {path_keys}")
        rows = []
        for _, row in df.iterrows():
            file_path = row[path_col]
            if not file_path or not Path(file_path).exists():
                continue
            try:
                with open(file_path, 'r', encoding='utf-8') as f:
                    chunks = json.load(f)
                chunks = chunks if isinstance(chunks, list) else [chunks]
                for chunk in chunks:
                    if self.qa_key in chunk:
                        rows.append({
                            self.qa_key: chunk[self.qa_key],
                            'source_file': file_path
                        })
            except Exception as e:
                self.logger.error(f"加载失败 {file_path}: {e}")
        if not rows:
            raise ValueError("未找到有效QA数据")
        return pd.DataFrame(rows)

    def run(
        self,
        storage: DataFlowStorage,
        input_key: Optional[str] = None,
        output_key: Optional[str] = None
    ) -> List[str]:
        """Extract QA pairs into Alpaca format and write them back to storage."""
        import pandas as pd
        self.logger.info("开始提取QA对...")
        df = storage.read(output_type="dataframe")
        # fall back to loading from chunk files when the QA column is absent
        if self.qa_key not in df.columns:
            df = self._load_from_files(df)
        # extract Alpaca records from every row
        fields = self._parse_fields(input_key)
        all_qas = []
        for _, row in df.iterrows():
            all_qas.extend(self._extract_qa(row, fields))
        self.logger.info(f"提取了 {len(all_qas)} 个QA对")
        if not all_qas:
            self.logger.warning("未提取到QA对!")
            return ['instruction', 'input', 'output']
        # optionally dump the records to a standalone JSON file
        if self.output_json_file:
            output_path = Path(self.output_json_file)
            output_path.parent.mkdir(parents=True, exist_ok=True)
            with open(output_path, 'w', encoding='utf-8') as f:
                json.dump(all_qas, f, indent=2, ensure_ascii=False)
            self.logger.info(f"已保存到 {output_path}")
        # write the records back to storage
        storage.write(pd.DataFrame(all_qas))
        return ['instruction', 'input', 'output']
\ No newline at end of file
from typing import TYPE_CHECKING

# Lazy-import shim: static type checkers see the concrete imports below,
# while at runtime the module object is replaced by a LazyLoader that
# imports each operator only on first attribute access.
if TYPE_CHECKING:
    from .generate.vqa_extractor import VQAExtractor
else:
    import sys
    from dataflow.utils.registry import LazyLoader, generate_import_structure_from_type_checking

    cur_path = "dataflow/operators/pdf2vqa/"
    # Build the lazy import table from the TYPE_CHECKING block above.
    _import_structure = generate_import_structure_from_type_checking(__file__, cur_path)
    sys.modules[__name__] = LazyLoader(__name__, "dataflow/operators/pdf2vqa/", _import_structure)
import os
import json
import re
import pandas as pd
import tiktoken
import shutil
import torch
from pathlib import Path
from typing import Literal
from dataflow.core import OperatorABC
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow.utils.storage import DataFlowStorage
from dataflow import get_logger
from dataflow.core import LLMServingABC
from dataflow.prompts.pdf2vqa import QAExtractPrompt
from dataflow.core.prompt import prompt_restrict
from dataflow.utils.pdf2vqa.format_utils import merge_qa_pair, jsonl_to_md
@prompt_restrict(QAExtractPrompt)
@OPERATOR_REGISTRY.register()
class VQAExtractor(OperatorABC):
    def __init__(self,
                llm_serving: LLMServingABC = None,
                mineru_backend: Literal["vlm-transformers","vlm-vllm-engine"] = "vlm-transformers",
                max_chunk_len: int = 128000,):
        """
        Args:
            llm_serving: LLM backend used to extract QA pairs from parsed text.
            mineru_backend: MinerU backend used for PDF layout extraction.
            max_chunk_len: Maximum chunk length sent to the LLM — presumably
                measured in cl100k_base tokens (see ``_count_tokens``); confirm.
        """
        self.logger = get_logger()
        self.llm_serving = llm_serving
        self.prompt_template = QAExtractPrompt()
        self.mineru_backend = mineru_backend
        self.max_chunk_len = max_chunk_len
def _convert_json(self, input_file, output_file):
with open(input_file, 'r') as infile:
data = list(json.load(infile))
new_data = []
id = 0
for item in data:
item['id'] = id
item.pop('bbox', None)
item.pop('page_idx', None)
if item.get('type','') == 'list':
if item['sub_type'] == 'text':
for idx, list_item in enumerate(item.get('list_items', [])):
new_item = {
'type': 'text',
'text': list_item,
'id': id + idx,
}
new_data.append(new_item)
id += len(item.get('list_items', []))
else:
new_data.append(item)
id += 1
with open(output_file, 'w') as outfile:
json.dump(new_data, outfile, ensure_ascii=False)
def _count_tokens(self, text: str) -> int:
enc = tiktoken.get_encoding("cl100k_base")
return len(enc.encode(text))
def _id_to_text(self, input_ids, input_json, image_prefix="images"):
texts = []
id_list = input_ids.replace(' ', '').split(',')
for id in id_list:
try:
int(id)
except:
continue
if int(id) < len(input_json):
try:
item = input_json[int(id)]
except:
continue
if 'text' in item:
texts.append(item['text'])
elif 'img_path' in item:
try:
img_path = item.get('img_path', '')
img_name = os.path.basename(img_path)
new_path = f"{image_prefix}/{img_name}"
texts.append(f"![{' '.join(item.get('image_caption','image'))}]({new_path})")
except:
pass
elif item.get('type','') == 'list':
if item['sub_type'] == 'text':
try:
texts.append(input_json[int(id)]['list_items'].pop(0))
except:
pass
return '\n'.join(texts)
def _extract_doc_layout(self, input_pdf_file_path: str, output_folder: str, mineru_backend: Literal["vlm-transformers","vlm-vllm-engine"] = "vlm-transformers"):
"""提取 PDF 的布局信息(合并自 VQAExtractDocLayoutMinerU)"""
try:
import mineru
from mineru.cli.client import main as mineru_main
except ImportError:
raise Exception(
"""
MinerU is not installed in this environment yet.
Please refer to https://github.com/opendatalab/mineru to install.
Or you can just execute 'pip install mineru[pipeline]' and 'mineru-models-download' to fix this error.
Please make sure you have GPU on your machine.
"""
)
try:
from pypdf import PdfReader, PdfWriter, PageObject
except ImportError:
raise Exception(
"""
pypdf is not installed in this environment yet.
Please use pip install pypdf.
"""
)
try:
from reportlab.pdfgen import canvas
except ImportError:
raise Exception(
"""
reportlab is not installed in this environment yet.
Please use pip install reportlab.
"""
)
os.environ['MINERU_MODEL_SOURCE'] = "local"
MinerU_Version = {"pipeline": "auto", "vlm-transformers": "vlm", "vlm-vllm-engine": "vlm"}
if mineru_backend == "pipeline":
raise ValueError("The 'pipeline' backend is not supported due to its incompatible output format. Please use 'vlm-transformers' or 'vlm-vllm-engine' instead.")
raw_file = Path(input_pdf_file_path)
pdf_name = raw_file.stem
intermediate_dir = output_folder
args = [
"-p", str(raw_file),
"-o", str(intermediate_dir),
"-b", mineru_backend,
"--source", "local"
]
if mineru_backend == "vlm-vllm-engine":
assert torch.cuda.is_available(), "MinerU vlm-vllm-engine backend requires GPU support."
args += ["--tensor-parallel-size", "2" if torch.cuda.device_count() >= 2 else "1"]
try:
mineru_main(args)
except SystemExit as e:
if e.code != 0:
raise RuntimeError(f"MinerU execution failed with exit code: {e.code}")
output_json_file = os.path.join(intermediate_dir, pdf_name, MinerU_Version[mineru_backend], f"{pdf_name}_content_list.json")
output_layout_file = os.path.join(intermediate_dir, pdf_name, MinerU_Version[mineru_backend], f"{pdf_name}_layout.pdf")
return output_json_file, output_layout_file
def _convert_response(self, input_response, input_json_path, image_prefix="images"):
qa_list = []
with open(input_json_path, 'r') as infile:
input_json = list(json.load(infile))
# 提取title
for chapter_block in re.findall(r'<chapter>(.*?)</chapter>', input_response, flags=re.DOTALL):
title = re.search(r'<title>(.*?)</title>', chapter_block, flags=re.DOTALL)
if title:
chapter_title = self._id_to_text(title.group(1).strip(), input_json, image_prefix)
else:
chapter_title = ""
# 找出所有 qa_pair 块
for pair in re.findall(r'<qa_pair>(.*?)</qa_pair>', chapter_block, flags=re.DOTALL):
# 提取 question 部分
q_match = re.search(r'<question>(.*?)</question>', pair, flags=re.DOTALL)
# 提取 answer 部分
a_match = re.search(r'<answer>(.*?)</answer>', pair, flags=re.DOTALL)
# 提取solution部分
s_match = re.search(r'<solution>(.*?)</solution>', pair, flags=re.DOTALL)
# 提取label
label_match = re.search(r'<label>(.*?)</label>', pair, flags=re.DOTALL)
if not ((q_match and label_match) or (a_match and label_match) or (s_match and label_match)):
continue
label = label_match.group(1).strip()
qa_list.append({
'question': self._id_to_text(q_match.group(1).strip(), input_json, image_prefix) if q_match else "",
'answer': a_match.group(1).strip() if a_match else "",
'solution': self._id_to_text(s_match.group(1).strip(), input_json, image_prefix) if s_match else "",
'label': label,
'chapter_title': chapter_title
})
return qa_list
def run(self, storage: DataFlowStorage,
input_question_pdf_path_key: str = "question_pdf_path",
input_answer_pdf_path_key: str = "answer_pdf_path",
input_pdf_path_key: str = "pdf_path", # 支持 interleaved 模式的单一 pdf_path
input_subject_key: str = "subject",
output_dir_key: str = "output_dir",
output_jsonl_key: str = "output_jsonl_path",
output_default_dir: str = "../vqa_output") -> list:
dataframe = storage.read("dataframe")
# 支持两种输入格式:question_pdf_path/answer_pdf_path 或 pdf_path
if input_question_pdf_path_key not in dataframe.columns and input_pdf_path_key not in dataframe.columns:
raise ValueError(f"Column '{input_question_pdf_path_key}' or '{input_pdf_path_key}' not found in dataframe")
# ========== Stage 1: 预处理(任务扩展 + Layout 提取) ==========
expanded_rows = []
for idx, row in dataframe.iterrows():
# 优先使用 question_pdf_path,如果没有则使用 pdf_path(interleaved 模式)
if input_question_pdf_path_key in dataframe.columns:
question_pdf_path = row[input_question_pdf_path_key]
answer_pdf_path = row.get(input_answer_pdf_path_key, question_pdf_path)
else:
# interleaved 模式:使用同一个 pdf_path
question_pdf_path = row[input_pdf_path_key]
answer_pdf_path = question_pdf_path
subject = row.get(input_subject_key, "math")
output_root = row.get(output_dir_key, output_default_dir)
interleaved = (question_pdf_path == answer_pdf_path)
os.makedirs(output_root, exist_ok=True)
# Question task
q_outdir = os.path.join(output_root, "question")
os.makedirs(q_outdir, exist_ok=True)
# Layout 提取
q_json_path, _ = self._extract_doc_layout(
input_pdf_file_path=question_pdf_path,
output_folder=q_outdir,
mineru_backend=self.mineru_backend
)
expanded_rows.append({
"pdf_path": question_pdf_path,
"mode": "question",
"interleaved": interleaved,
"subject": subject,
"output_dir": q_outdir,
"output_root": output_root,
"json_path": q_json_path
})
# Answer task (if not interleaved)
if not interleaved:
a_outdir = os.path.join(output_root, "answer")
os.makedirs(a_outdir, exist_ok=True)
# Layout 提取
a_json_path, _ = self._extract_doc_layout(
input_pdf_file_path=answer_pdf_path,
output_folder=a_outdir,
mineru_backend=self.mineru_backend
)
expanded_rows.append({
"pdf_path": answer_pdf_path,
"mode": "answer",
"interleaved": interleaved,
"subject": subject,
"output_dir": a_outdir,
"output_root": output_root,
"json_path": a_json_path
})
# ========== Stage 2: QA 提取 ==========
json_paths = [row["json_path"] for row in expanded_rows]
subjects = [row["subject"] for row in expanded_rows]
user_inputs = []
split_metadata = []
for idx, input_json_path in enumerate(json_paths):
subject = subjects[idx] if idx < len(subjects) else subjects[0] if subjects else "math"
system_prompt = self.prompt_template.build_prompt(subject)
system_prompt_len = self._count_tokens(system_prompt)
converted_path = input_json_path.replace('.json', '_converted.json')
self._convert_json(input_json_path, converted_path)
with open(converted_path, 'r') as infile:
data = json.load(infile)
assert isinstance(data, list), f"Expected list, got {type(data)} for {input_json_path}"
# 分段处理
current_chunk, current_len = [], system_prompt_len
chunks = []
for item in data:
text = json.dumps(item, ensure_ascii=False)
item_len = self._count_tokens(text)
if current_len + item_len > self.max_chunk_len and current_chunk:
chunks.append(current_chunk)
current_chunk, current_len = [], system_prompt_len
current_chunk.append(item)
current_len += item_len
if current_chunk:
chunks.append(current_chunk)
split_metadata.append(len(chunks))
for chunk in chunks:
user_inputs.append({
'user_input': json.dumps(chunk, ensure_ascii=False),
'system_prompt': system_prompt
})
# 批量生成
responses = [None] * len(user_inputs)
current_batch = []
current_batch_indices = []
current_system_prompt = None
for idx, item in enumerate(user_inputs):
user_input = item['user_input']
system_prompt = item['system_prompt']
if current_system_prompt is None:
current_system_prompt = system_prompt
current_batch = [user_input]
current_batch_indices = [idx]
elif system_prompt == current_system_prompt:
current_batch.append(user_input)
current_batch_indices.append(idx)
else:
# 处理当前批次
batch_responses = self.llm_serving.generate_from_input(user_inputs=current_batch, system_prompt=current_system_prompt)
for batch_idx, resp in zip(current_batch_indices, batch_responses):
responses[batch_idx] = resp
# 开始新批次
current_system_prompt = system_prompt
current_batch = [user_input]
current_batch_indices = [idx]
# 处理最后一批
if current_batch:
batch_responses = self.llm_serving.generate_from_input(user_inputs=current_batch, system_prompt=current_system_prompt)
for batch_idx, resp in zip(current_batch_indices, batch_responses):
responses[batch_idx] = resp
# 按 split_metadata 还原
recombined_responses = []
idx = 0
for num_chunks in split_metadata:
merged_text = "\n".join(responses[idx: idx + num_chunks])
recombined_responses.append(merged_text)
idx += num_chunks
# ========== Stage 3: 后处理(Response 转换 + 合并和过滤) ==========
# Response 转换
qa_lists = []
for idx, (response, row) in enumerate(zip(recombined_responses, expanded_rows)):
json_path = row["json_path"]
output_dir = row["output_dir"]
mode = row["mode"]
output_root = row["output_root"]
image_prefix = f"{mode}_images"
converted_json_path = json_path.replace('.json', '_converted.json')
qa_list = self._convert_response(response, converted_json_path, image_prefix)
# 复制图片
src_dir = os.path.join(output_dir, Path(json_path).stem).replace('_content_list','')
src_images = os.path.join(src_dir, 'vlm', 'images')
dst_images = os.path.join(output_root, image_prefix)
try:
if os.path.exists(src_images):
if os.path.exists(dst_images):
shutil.rmtree(dst_images)
shutil.copytree(src_images, dst_images)
else:
self.logger.warning(f"Source images dir does not exist: {src_images}")
except Exception as e:
self.logger.warning(f"Failed to copy images from {src_images} to {dst_images}: {e}")
qa_lists.append(qa_list)
# 按 output_root 分组处理合并和过滤
output_groups = {}
for idx, (qa_list, row) in enumerate(zip(qa_lists, expanded_rows)):
output_root = row["output_root"]
mode = row["mode"]
interleaved = row["interleaved"]
output_dir = row["output_dir"]
if output_root not in output_groups:
output_groups[output_root] = {
"question": None,
"answer": None,
"interleaved": interleaved
}
if mode == "question":
output_groups[output_root]["question"] = (qa_list, output_dir)
elif mode == "answer":
output_groups[output_root]["answer"] = (qa_list, output_dir)
# 处理每个 output_root
result_paths_dict = {}
for output_root, group_info in output_groups.items():
q_qa_list, q_output_dir = group_info["question"] if group_info["question"] else (None, None)
a_qa_list, a_output_dir = group_info["answer"] if group_info["answer"] else (None, None)
interleaved = group_info["interleaved"]
# 写入 question jsonl
q_jsonl_path = os.path.join(output_root, "vqa_extracted_questions.jsonl")
if q_qa_list:
with open(q_jsonl_path, 'w', encoding='utf-8') as f:
for item in q_qa_list:
f.write(json.dumps(item, ensure_ascii=False) + '\n')
# 写入 answer jsonl(如果不是 interleaved)
a_jsonl_path = None
if not interleaved and a_qa_list:
a_jsonl_path = os.path.join(output_root, "vqa_extracted_answers.jsonl")
with open(a_jsonl_path, 'w', encoding='utf-8') as f:
for item in a_qa_list:
f.write(json.dumps(item, ensure_ascii=False) + '\n')
# 合并
merged_jsonl = os.path.join(output_root, "vqa_merged_qa_pairs.jsonl")
if not interleaved and a_jsonl_path:
merge_qa_pair(q_jsonl_path, a_jsonl_path, merged_jsonl)
else:
os.system(f"cp {q_jsonl_path} {merged_jsonl}")
# 过滤
filtered_items = []
total_count = 0
with open(merged_jsonl, 'r', encoding='utf-8') as f:
for line in f:
total_count += 1
item = json.loads(line)
if item.get('question','').strip() and (item.get('answer','').strip() or item.get('solution','').strip()):
filtered_items.append(item)
self.logger.info(f"Before filter: {total_count}, After filter: {len(filtered_items)}")
filtered_jsonl = os.path.join(output_root, "vqa_filtered_qa_pairs.jsonl")
with open(filtered_jsonl, 'w', encoding='utf-8') as f:
for item in filtered_items:
f.write(json.dumps(item, ensure_ascii=False) + '\n')
# 转换为 markdown
md_output = os.path.join(output_root, "vqa_filtered_qa_pairs.md")
jsonl_to_md(filtered_jsonl, md_output)
result_paths_dict[output_root] = filtered_jsonl
# 为原始 dataframe 的每一行分配结果路径
result_paths = []
for idx, row in dataframe.iterrows():
if input_question_pdf_path_key in dataframe.columns:
question_pdf_path = row[input_question_pdf_path_key]
answer_pdf_path = row.get(input_answer_pdf_path_key, question_pdf_path)
else:
question_pdf_path = row[input_pdf_path_key]
answer_pdf_path = question_pdf_path
output_root = row.get(output_dir_key, output_default_dir)
result_paths.append(result_paths_dict.get(output_root))
dataframe[output_jsonl_key] = result_paths
output_file = storage.write(dataframe)
self.logger.info(f"VQA extraction complete. Results saved to {output_file}")
return [output_jsonl_key,]
from typing import TYPE_CHECKING
# The TYPE_CHECKING branch is only seen by static type checkers / IDEs; at
# runtime the module object is replaced with a LazyLoader so each operator
# module is imported on first attribute access instead of eagerly.
if TYPE_CHECKING:
    # generate
    from .generate.reasoning_answer_generator import ReasoningAnswerGenerator
    from .generate.reasoning_question_generator import ReasoningQuestionGenerator
    from .generate.reasoning_answer_extraction_qwenmatheval_generator import ReasoningAnswerExtractionQwenMathEvalGenerator
    from .generate.reasoning_pseudo_answer_generator import ReasoningPseudoAnswerGenerator
    from .generate.reasoning_pretrain_format_convert_generator import ReasoningPretrainFormatConvertGenerator
    from .generate.reasoning_question_fusion_generator import ReasoningQuestionFusionGenerator
    # eval
    from .eval.reasoning_category_dataset_evaluator import ReasoningCategoryDatasetEvaluator
    from .eval.reasoning_difficulty_dataset_evaluator import ReasoningDifficultyDatasetEvaluator
    from .eval.reasoning_token_dataset_evaluator import ReasoningTokenDatasetEvaluator
    from .eval.reasoning_question_category_sample_evaluator import ReasoningQuestionCategorySampleEvaluator
    from .eval.reasoning_question_difficulty_sample_evaluator import ReasoningQuestionDifficultySampleEvaluator
    from .eval.reasoning_question_solvable_sample_evaluator import ReasoningQuestionSolvableSampleEvaluator
    # filter
    from .filter.reasoning_answer_formatter_filter import ReasoningAnswerFormatterFilter
    from .filter.reasoning_answer_groundtruth_filter import ReasoningAnswerGroundTruthFilter
    from .filter.reasoning_answer_ngram_filter import ReasoningAnswerNgramFilter
    from .filter.reasoning_answer_pipeline_root_filter import ReasoningAnswerPipelineRootFilter
    from .filter.reasoning_answer_token_length_filter import ReasoningAnswerTokenLengthFilter
    from .filter.reasoning_question_filter import ReasoningQuestionFilter
    from .filter.reasoning_answer_model_judge_filter import ReasoningAnswerModelJudgeFilter
else:
    import sys
    from dataflow.utils.registry import LazyLoader, generate_import_structure_from_type_checking
    # Derive the name -> module map from the TYPE_CHECKING imports above,
    # then swap this module for a LazyLoader that resolves names on demand.
    cur_path = "dataflow/operators/reasoning/"
    _import_structure = generate_import_structure_from_type_checking(__file__, cur_path)
    sys.modules[__name__] = LazyLoader(__name__, "dataflow/operators/reasoning/", _import_structure)
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow import get_logger
from dataflow.core import OperatorABC
from dataflow.utils.storage import DataFlowStorage
from dataflow.utils.reasoning.CategoryFuzz import CategoryUtils
import pandas as pd
@OPERATOR_REGISTRY.register()
class ReasoningCategoryDatasetEvaluator(OperatorABC):
    """Dataset-level evaluator reporting primary/secondary category counts."""
    def __init__(self):
        self.logger = get_logger()
        self.logger.info(f'Initializing {self.__class__.__name__}...')
        self.logger.info(f'{self.__class__.__name__} initialized.')
        self.information_name = "Category Information"
        # Mapping: primary category -> iterable of its secondary categories.
        self.category_list = CategoryUtils().secondary_categories
    @staticmethod
    def get_desc(lang: str = "zh"):
        if lang == "zh":
            return (
                "该算子用于统计数据集中的类别信息,包括主类别和次类别的分布情况。"
                "它计算每个类别的样本数量,并返回类别分布的统计结果。\n"
                "输入参数:\n"
                "- input_primary_category_key:主类别字段名,默认为'primary_category'\n"
                "- input_secondary_category_key:次类别字段名,默认为'secondary_category'\n"
                "输出参数:\n"
                "- 返回包含类别统计信息的字典,主类别作为键,值为包含该类别样本数量和次类别分布的字典"
            )
        elif lang == "en":
            return (
                "This operator analyzes category distribution in the dataset, including primary and secondary categories. "
                "It counts the number of samples in each category and returns statistical results of category distribution.\n"
                "Input Parameters:\n"
                "- input_primary_category_key: Field name for primary category, default is 'primary_category'\n"
                "- input_secondary_category_key: Field name for secondary category, default is 'secondary_category'\n\n"
                "Output Parameters:\n"
                "- Returns a dictionary containing category statistics, with primary categories as keys and values as dictionaries "
                "containing sample counts and secondary category distribution"
            )
        else:
            return (
                "CategoryInfo analyzes and reports the distribution of primary and secondary categories in the dataset."
            )
    def get_category_info(self, samples, input_primary_category_key = "primary_category", input_secondary_category_key = "secondary_category"):
        """Count samples per primary/secondary category.

        Args:
            samples: List of record dicts (dataframe rows).
            input_primary_category_key: Key holding the primary category.
            input_secondary_category_key: Key holding the secondary category.

        Returns:
            dict: primary category -> {"primary_num": count, <secondary>: count, ...}
        """
        primary_categories = [sample.get(input_primary_category_key, '') for sample in samples]
        secondary_categories = [sample.get(input_secondary_category_key, '') for sample in samples]
        primary_categories_count = pd.Series(primary_categories).value_counts().to_dict()
        secondary_categories_count = pd.Series(secondary_categories).value_counts().to_dict()
        # Bug fix: this was `output = []`; indexing a list with a string key
        # (`output[primary] = js`) raised TypeError as soon as any primary
        # category matched. The documented return value is a dict.
        output = {}
        for primary in self.category_list:
            js = {}
            if primary not in primary_categories_count:
                continue
            js["primary_num"] = primary_categories_count[primary]
            for secondary in self.category_list[primary]:
                if secondary not in secondary_categories_count:
                    continue
                js[secondary] = secondary_categories_count[secondary]
            output[primary] = js
        self.logger.info(f"Category information: {output}")
        return output
    def run(self, storage: DataFlowStorage, input_primary_category_key: str = "primary_category", input_secondary_category_key: str = "secondary_category"):
        """Read the dataframe from storage and return its category statistics.

        Returns {} (and logs an error) when either category column is missing.
        """
        self.input_primary_category_key = input_primary_category_key
        self.input_secondary_category_key = input_secondary_category_key
        dataframe = storage.read("dataframe")
        if self.input_primary_category_key not in dataframe.columns or self.input_secondary_category_key not in dataframe.columns:
            self.logger.error(f"Input keys {self.input_primary_category_key} or {self.input_secondary_category_key} not found in dataframe columns.")
            return {}
        samples = dataframe.to_dict(orient='records')
        category_info = self.get_category_info(samples, self.input_primary_category_key, self.input_secondary_category_key)
        return category_info
\ No newline at end of file
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow import get_logger
from dataflow.core import OperatorABC
from dataflow.utils.storage import DataFlowStorage
import pandas as pd
@OPERATOR_REGISTRY.register()
class ReasoningDifficultyDatasetEvaluator(OperatorABC):
    """Dataset-level evaluator reporting the distribution of difficulty scores."""
    def __init__(self):
        self.logger = get_logger()
        self.logger.info(f'Initializing {self.__class__.__name__}...')
        self.logger.info(f'{self.__class__.__name__} initialized.')
        # Label used when reporting this evaluator's statistics.
        self.information_name = "Difficulty Information"
    @staticmethod
    def get_desc(lang: str = "zh"):
        """Return a human-readable description of this operator in `lang`."""
        descriptions = {
            "zh": (
                "该算子用于统计数据集中的难度信息,计算不同难度级别的样本数量分布。"
                "它统计每个难度级别的样本数量,并返回难度分布的统计结果。\n"
                "输入参数:\n"
                "- input_diffulty_key:难度分数字段名,默认为'difficulty_score'\n"
                "输出参数:\n"
                "- 返回包含难度统计信息的字典,难度级别作为键,值为该难度级别的样本数量"
            ),
            "en": (
                "This operator analyzes difficulty distribution in the dataset, calculating the number of samples at different difficulty levels. "
                "It counts samples at each difficulty level and returns statistical results of difficulty distribution.\n"
                "Input Parameters:\n"
                "- input_diffulty_key: Field name for difficulty score, default is 'difficulty_score'\n\n"
                "Output Parameters:\n"
                "- Returns a dictionary containing difficulty statistics, with difficulty levels as keys and sample counts as values"
            ),
        }
        fallback = (
            "DifficultyInfo analyzes and reports the distribution of difficulty levels in the dataset."
        )
        return descriptions.get(lang, fallback)
    def get_category_info(self, samples, input_diffulty_key="difficulty_score"):
        """Count samples per difficulty value ('null' when the key is absent)."""
        scores = [record.get(input_diffulty_key, 'null') for record in samples]
        distribution = pd.Series(scores).value_counts().to_dict()
        self.logger.info(f"Difficulty information: {distribution}")
        return distribution
    def run(self, storage: DataFlowStorage, input_diffulty_key: str = "difficulty_score"):
        """Read the dataframe from storage and return its difficulty distribution.

        Returns {} (and logs an error) when the difficulty column is missing.
        """
        self.input_diffulty_key = input_diffulty_key
        dataframe = storage.read("dataframe")
        if self.input_diffulty_key in dataframe.columns:
            records = dataframe.to_dict(orient='records')
            return self.get_category_info(records, self.input_diffulty_key)
        self.logger.error(f"Input key {self.input_diffulty_key} not found in dataframe columns.")
        return {}
\ No newline at end of file
from dataflow.utils.storage import DataFlowStorage
from dataflow.core import OperatorABC
from dataflow.core.prompt import prompt_restrict
from dataflow.utils.reasoning.CategoryFuzz import CategoryUtils
from dataflow.core import LLMServingABC
from dataflow.prompts.reasoning.math import MathQuestionCategoryPrompt
import pandas as pd
import json
import re
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow import get_logger
@prompt_restrict(
    MathQuestionCategoryPrompt
)
@OPERATOR_REGISTRY.register()
class ReasoningQuestionCategorySampleEvaluator(OperatorABC):
    """Classify each question into primary/secondary categories via an LLM.

    Adds 'primary_category' and 'secondary_category' columns to the dataframe;
    LLM responses are parsed as JSON (optionally wrapped in ```json fences)
    and normalized through CategoryUtils.
    """
    def __init__(self, llm_serving: LLMServingABC = None):
        """
        Initialize the evaluator with the LLM serving backend used for classification.
        """
        self.logger = get_logger()
        self.prompts = MathQuestionCategoryPrompt()
        self.llm_serving = llm_serving
    @staticmethod
    def get_desc(lang: str = "zh"):
        if lang == "zh":
            return (
                "该算子用于对用户问题进行多级分类(主分类和子分类)。"
                "通过大语言模型对输入问题进行语义分析,输出分类编码结果。\n\n"
                "输入参数:\n"
                "- db_port/db_name/table_name:数据库连接参数(存储模式)\n"
                "- input_file/output_file:文件路径(文件模式)\n"
                "- input_key:输入数据中问题字段的键名\n"
                "- generator_type:模型调用方式(aisuite/request)\n\n"
                "输出参数:\n"
                "- classification_result:包含主分类和子分类的编码结果"
            )
        elif lang == "en":
            return (
                "Performs hierarchical classification (primary and secondary) on user questions. "
                "Utilizes LLM for semantic analysis and outputs category codes.\n\n"
                "Input Parameters:\n"
                "- db_port/db_name/table_name: Database connection params (storage mode)\n"
                "- input_file/output_file: File paths (file mode)\n"
                "- input_key: Key for question field in input data\n"
                "- generator_type: Model invocation method (aisuite/request)\n\n"
                "Output Parameters:\n"
                "- classification_result: Combined category code"
            )
        else:
            # Consistency fix: sibling evaluators provide a fallback description;
            # previously this returned None for unsupported languages.
            return (
                "ReasoningQuestionCategorySampleEvaluator performs hierarchical (primary/secondary) "
                "classification of questions via an LLM."
            )
    def _validate_dataframe(self, dataframe: pd.DataFrame):
        """Raise ValueError if the input column is missing or the output column exists."""
        required_keys = [self.input_key]
        forbidden_keys = [self.output_key]
        missing = [k for k in required_keys if k not in dataframe.columns]
        conflict = [k for k in forbidden_keys if k in dataframe.columns]
        if missing:
            raise ValueError(f"Missing required column(s): {missing}")
        if conflict:
            raise ValueError(f"The following column(s) already exist and would be overwritten: {conflict}")
    def _reformat_prompt(self, dataframe):
        """
        Build one classification prompt per question in the dataframe.
        """
        formatted_prompts = []
        for text in dataframe[self.input_key]:
            used_prompt = self.prompts.build_prompt(text)
            formatted_prompts.append(used_prompt.strip())
        return formatted_prompts
    def run(self, storage: DataFlowStorage, input_key: str = "instruction", output_key: str = "question_category") -> list:
        """
        Run the question category classification process.

        Returns:
            The names of the columns written: ["primary_category", "secondary_category"].
        """
        self.input_key, self.output_key = input_key, output_key
        dataframe = storage.read("dataframe")
        self._validate_dataframe(dataframe)
        formatted_prompts = self._reformat_prompt(dataframe)
        responses = self.llm_serving.generate_from_input(formatted_prompts)
        for (idx, row), classification_str in zip(dataframe.iterrows(), responses):
            try:
                if not classification_str:
                    raise ValueError("空字符串")
                # Strip a Markdown ```json code fence if present
                cleaned_str = re.sub(r"^```json\s*|\s*```$", "", classification_str.strip(), flags=re.DOTALL)
                # Drop non-ASCII characters (optional)
                cleaned_str = re.sub(r"[^\x00-\x7F]+", "", cleaned_str)
                classification = json.loads(cleaned_str)
                primary_raw = classification.get("primary_category", "")
                secondary_raw = classification.get("secondary_category", "")
                category_info = CategoryUtils().normalize_categories(raw_primary=primary_raw, raw_secondary=secondary_raw)
                dataframe.at[idx, "primary_category"] = category_info["primary_category"]
                dataframe.at[idx, "secondary_category"] = category_info["secondary_category"]
            except json.JSONDecodeError:
                self.logger.warning(f"[警告] JSON 解析失败,收到的原始数据: {repr(classification_str)}")
            except Exception as e:
                self.logger.error(f"[错误] 解析分类结果失败: {e}")
                self.logger.debug(f"[DEBUG] 原始字符串:{repr(classification_str)}")
        output_file = storage.write(dataframe)
        self.logger.info(f"Classification results saved to {output_file}")
        return ["primary_category", "secondary_category"]
\ No newline at end of file
from dataflow.prompts.reasoning.math import MathQuestionDifficultyPrompt
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow import get_logger
from dataflow.core.prompt import prompt_restrict
from dataflow.utils.storage import DataFlowStorage
from dataflow.core import OperatorABC
from dataflow.core import LLMServingABC
import pandas as pd
import re
@prompt_restrict(
    MathQuestionDifficultyPrompt
)
@OPERATOR_REGISTRY.register()
class ReasoningQuestionDifficultySampleEvaluator(OperatorABC):
    """Rate each question's difficulty (1-10) via an LLM.

    Responses are scanned for a 'Rating: <number>' pattern; unparsable
    responses yield a score of -1 in the output column.
    """
    def __init__(self, llm_serving: LLMServingABC = None):
        """
        Initialize the evaluator with the LLM serving backend used for rating.
        """
        self.logger = get_logger()
        self.prompts = MathQuestionDifficultyPrompt()
        self.llm_serving = llm_serving
    @staticmethod
    def get_desc(lang: str = "zh"):
        if lang == "zh":
            return (
                "该算子用于评估问题的难度等级。"
                "通过大语言模型分析问题复杂度,输出1-10级的难度评分。\n\n"
                "输入参数:\n"
                "- eval_stage:评估阶段标识\n"
                "- read_min/max_score:分数过滤阈值\n"
                "- 其他参数同ReasoningCategoryDatasetEvaluator\n\n"
                "输出参数:\n"
                "- difficulty_score:数值型难度评分(1-10)"
            )
        elif lang == "en":
            return (
                "Evaluates question difficulty level using LLM analysis. "
                "Outputs numerical difficulty score from 1 to 10.\n\n"
                "Input Parameters:\n"
                "- eval_stage: Evaluation stage identifier\n"
                "- read_min/max_score: Score filtering thresholds\n"
                "- Other params same as ReasoningCategoryDatasetEvaluator\n\n"
                "Output Parameters:\n"
                "- difficulty_score: Numerical difficulty rating (1-10)"
            )
        else:
            # Fallback for unsupported languages, consistent with sibling evaluators.
            return (
                "ReasoningQuestionDifficultySampleEvaluator rates question difficulty (1-10) via an LLM."
            )
    def _validate_dataframe(self, dataframe: pd.DataFrame, input_key: str = "instruction", output_key: str = "difficulty_score"):
        """Raise ValueError if the input column is missing or the output column exists."""
        required_keys = [self.input_key]
        forbidden_keys = [self.output_key]
        missing = [k for k in required_keys if k not in dataframe.columns]
        conflict = [k for k in forbidden_keys if k in dataframe.columns]
        if missing:
            raise ValueError(f"Missing required column(s): {missing}")
        if conflict:
            raise ValueError(f"The following column(s) already exist and would be overwritten: {conflict}")
    def _reformat_prompt(self, dataframe, input_key: str = "instruction") -> list:
        """
        Build one difficulty prompt per row; None entries are propagated
        unchanged instead of being formatted.
        """
        formatted_prompts = []
        for text in dataframe[input_key]:
            if text is not None:
                used_prompt = self.prompts.build_prompt(text)
            else:
                used_prompt = None
            # Bug fix: the original unconditionally called used_prompt.strip(),
            # which raised AttributeError whenever text (and thus used_prompt)
            # was None.
            formatted_prompts.append(used_prompt.strip() if used_prompt is not None else None)
        return formatted_prompts
    def run(self, storage: DataFlowStorage, input_key: str, output_key: str = "difficulty_score") -> list:
        """
        Run the question difficulty classification process.

        Returns:
            [output_key]: the name of the score column added to the dataframe.
        """
        self.input_key, self.output_key = input_key, output_key
        dataframe = storage.read("dataframe")
        self._validate_dataframe(
            dataframe,
            input_key=self.input_key,
            output_key=self.output_key
        )
        formatted_prompts = self._reformat_prompt(dataframe, input_key=self.input_key)
        responses = self.llm_serving.generate_from_input(user_inputs=formatted_prompts)
        rating_scores = []
        for response in responses:
            # Accept integer or decimal ratings, e.g. "Rating: 7" / "Rating: 7.5".
            match = re.search(r'Rating:\s*((\d+\.\d+)|\d+)', response)
            if match:
                score_str = match.group(1).rstrip('.')
                try:
                    score = float(score_str)
                except ValueError:
                    score = -1
            else:
                score = -1  # Sentinel for "no parsable rating in the response".
            rating_scores.append(score)
        dataframe[output_key] = rating_scores
        output_file = storage.write(dataframe)
        self.logger.info(f"Classification results saved to {output_file}")
        return [output_key]
\ No newline at end of file
from dataflow.utils.storage import DataFlowStorage
from dataflow.core import OperatorABC
from dataflow.core.prompt import prompt_restrict, DIYPromptABC
from dataflow.core import LLMServingABC
from dataflow.prompts.reasoning.math import MathQuestionEvaluatorPrompt
from typing import Union
import pandas as pd
import json
import re
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow import get_logger
@prompt_restrict(
    MathQuestionEvaluatorPrompt
)
@OPERATOR_REGISTRY.register()
class ReasoningQuestionSolvableSampleEvaluator(OperatorABC):
    """Ask an LLM to evaluate each question (e.g. solvability); the raw LLM
    responses are stored under `output_key` in the dataframe."""
    def __init__(self, llm_serving: LLMServingABC = None, prompt_template: Union[MathQuestionEvaluatorPrompt, DIYPromptABC] = None):
        """
        Initialize the ReasoningCategoryDatasetEvaluator with the provided configuration.
        """
        self.logger = get_logger()
        # Fall back to the default math evaluator prompt when none is supplied.
        self.prompt_template = prompt_template if prompt_template is not None else MathQuestionEvaluatorPrompt()
        self.llm_serving = llm_serving
    @staticmethod
    def get_desc(lang: str = "zh"):
        # NOTE: returns None for languages other than "zh"/"en" (kept as-is).
        if lang == "en":
            return (
                "Performs hierarchical classification (primary and secondary) on user questions. "
                "Utilizes LLM for semantic analysis and outputs category codes.\n\n"
                "Input Parameters:\n"
                "- db_port/db_name/table_name: Database connection params (storage mode)\n"
                "- input_file/output_file: File paths (file mode)\n"
                "- input_key: Key for question field in input data\n"
                "- generator_type: Model invocation method (aisuite/request)\n\n"
                "Output Parameters:\n"
                "- classification_result: Combined category code"
            )
        if lang == "zh":
            return (
                "该算子用于对用户问题进行多级分类(主分类和子分类)。"
                "通过大语言模型对输入问题进行语义分析,输出分类编码结果。\n\n"
                "输入参数:\n"
                "- db_port/db_name/table_name:数据库连接参数(存储模式)\n"
                "- input_file/output_file:文件路径(文件模式)\n"
                "- input_key:输入数据中问题字段的键名\n"
                "- generator_type:模型调用方式(aisuite/request)\n\n"
                "输出参数:\n"
                "- classification_result:包含主分类和子分类的编码结果"
            )
    def _validate_dataframe(self, dataframe: pd.DataFrame):
        """Raise ValueError if the input column is missing or the output column exists."""
        columns = dataframe.columns
        missing = [key for key in [self.input_key] if key not in columns]
        conflict = [key for key in [self.output_key] if key in columns]
        if missing:
            raise ValueError(f"Missing required column(s): {missing}")
        if conflict:
            raise ValueError(f"The following column(s) already exist and would be overwritten: {conflict}")
    def _reformat_prompt(self, dataframe):
        """Build the shared system prompt and one user prompt per question."""
        questions = dataframe[self.input_key].tolist()
        system_prompt = self.prompt_template.build_system_prompt()
        user_prompts = [self.prompt_template.build_prompt(question) for question in questions]
        return system_prompt, user_prompts
    def run(self, storage: DataFlowStorage, input_key: str, output_key: str):
        """
        Run the question generation process.
        """
        self.input_key, self.output_key = input_key, output_key
        dataframe = storage.read("dataframe")
        self._validate_dataframe(dataframe)
        sys_prompts, user_prompts = self._reformat_prompt(dataframe)
        responses = self.llm_serving.generate_from_input(user_prompts, sys_prompts)
        dataframe[output_key] = responses
        self.logger.info(f"Generated questions for {output_key}")
        output_file = storage.write(dataframe)
        self.logger.info(f"Generated questions saved to {output_file}")
\ No newline at end of file
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow import get_logger
from dataflow.core import OperatorABC
from dataflow.utils.storage import DataFlowStorage
import pandas as pd
from transformers import AutoTokenizer
@OPERATOR_REGISTRY.register()
class ReasoningTokenDatasetEvaluator(OperatorABC):
    def __init__(self, model_name_or_path: str):
        """Store the tokenizer source used later by get_token_info.

        Args:
            model_name_or_path: Hugging Face model name or local path passed
                to AutoTokenizer.from_pretrained.
        """
        self.logger = get_logger()
        self.logger.info(f'Initializing {self.__class__.__name__}...')
        self.logger.info(f'{self.__class__.__name__} initialized.')
        # Label used when reporting this evaluator's statistics.
        self.information_name = "Token Information"
        self.model_name_or_path = model_name_or_path
    @staticmethod
    def get_desc(lang: str = "zh"):
        """Return a human-readable description of this operator in `lang`
        ("zh" or "en"); any other value falls through to a short English summary."""
        if lang == "zh":
            return (
                "该算子用于统计数据集中问题和回答的token信息,包括token数量的最小值、最大值、平均值和中位数等统计指标。"
                "它使用指定的tokenizer对文本进行编码,并计算token长度的分布情况。\n"
                "输入参数:\n"
                "- input_question_key:问题文本字段名\n"
                "- input_answer_key:回答文本字段名\n"
                "- model_name_or_path:tokenizer模型名称或路径\n"
                "输出参数:\n"
                "- 返回包含token统计信息的字典,包括问题和回答的token数量的零值计数、最小值、最大值、平均值和中位数"
            )
        elif lang == "en":
            return (
                "This operator analyzes token information for questions and answers in the dataset, including statistical metrics "
                "such as minimum, maximum, mean, and median token counts. It encodes text using the specified tokenizer and calculates "
                "token length distribution.\n"
                "Input Parameters:\n"
                "- input_question_key: Field name for question text\n"
                "- input_answer_key: Field name for answer text\n"
                "- model_name_or_path: Tokenizer model name or path\n\n"
                "Output Parameters:\n"
                "- Returns a dictionary containing token statistics, including zero count, minimum, maximum, mean, and median token counts "
                "for both questions and answers"
            )
        else:
            return (
                "ToKenInfo analyzes and reports token length statistics for questions and answers in the dataset using a specified tokenizer."
            )
def get_token_info(self, samples, input_question_key, input_answer_key, model_name_or_path):
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
questions = [sample.get(input_question_key, '') or '' for sample in samples]
answers = [sample.get(input_answer_key, '') or '' for sample in samples]
questions_tokens_length = [len(tokenizer.encode(question, add_special_tokens=False)) for question in questions]
answers_tokens_length = [len(tokenizer.encode(answer, add_special_tokens=False)) for answer in answers]
# count zeros in questions_tokens_length and answers_tokens_length
questions_zeros_count = questions_tokens_length.count(0)
answers_zeros_count = answers_tokens_length.count(0)
# count min,max,mean, median of questions_tokens_length and answers_tokens_length
questions_min = min(questions_tokens_length) if questions_tokens_length else 0
questions_max = max(questions_tokens_length) if questions_tokens_length else 0
questions_mean = sum(questions_tokens_length) / len(questions_tokens_length) if questions_tokens_length else 0
questions_median = sorted(questions_tokens_length)[len(questions_tokens_length) // 2] if questions_tokens_length else 0
answers_min = min(answers_tokens_length) if answers_tokens_length else 0
answers_max = max(answers_tokens_length) if answers_tokens_length else 0
answers_mean = sum(answers_tokens_length) / len(answers_tokens_length) if answers_tokens_length else 0
answers_median = sorted(answers_tokens_length)[len(answers_tokens_length) // 2] if answers_tokens_length else 0
token_info = {
"questions_zeros_count": questions_zeros_count,
"answers_zeros_count": answers_zeros_count,
"questions_min": questions_min,
"questions_max": questions_max,
"questions_mean": questions_mean,
"questions_median": questions_median,
"answers_min": answers_min,
"answers_max": answers_max,
"answers_mean": answers_mean,
"answers_median": answers_median
}
self.logger.info(f"Token information: {token_info}")
return token_info
def run(self,storage: DataFlowStorage, input_question_key: str, input_answer_key: str):
self.input_question_key = input_question_key
self.input_answer_key = input_answer_key
dataframe = storage.read("dataframe")
if self.input_question_key not in dataframe.columns:
self.logger.error(f"Input key {self.input_question_key} not found in dataframe columns.")
return {}
if self.input_answer_key not in dataframe.columns:
self.logger.warning(f"Input key {self.input_answer_key} not found in dataframe columns")
samples = dataframe.to_dict(orient='records')
token_info = self.get_token_info(samples, self.input_question_key, self.input_answer_key, self.model_name_or_path)
return token_info
\ No newline at end of file
import numpy as np
import pandas as pd
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow import get_logger
from dataflow.utils.storage import DataFlowStorage
from dataflow.core import OperatorABC
@OPERATOR_REGISTRY.register()
class ReasoningAnswerFormatterFilter(OperatorABC):
    """Filter dataframe rows by answer format.

    The format check is currently a stub that accepts everything (the
    ``\\boxed{}`` regex check is commented out); rows that pass are kept.
    """

    def __init__(self):
        self.logger = get_logger()

    @staticmethod
    def is_valid_answer(answer: str) -> bool:
        # Declared @staticmethod: the original definition had no `self`
        # and was only callable through the class, which broke instance calls.
        # check final answer in \boxed{} or not
        # if not re.search(r'\\boxed{.*}', answer):
        #     return False
        return True

    @staticmethod
    def get_desc(lang: str = "zh"):
        if lang == "zh":
            return (
                "该算子用于检查答案格式是否符合规范,主要验证数学答案是否包含正确的\\boxed{}标记。\n\n"
                "输入参数:\n"
                "- input_key:输入字段名\n"
                "- result_key:结果字段名\n\n"
                "输出参数:\n"
                "- 通过格式检查返回1,否则返回0"
            )
        elif lang == "en":
            return (
                "This operator validates answer formatting, specifically checking for correct \\boxed{} notation.\n\n"
                "Input Parameters:\n"
                "- input_key: Field name containing the answer\n"
                "- result_key: Output result field name\n\n"
                "Output Parameters:\n"
                "- Returns 1 for valid format, 0 otherwise"
            )
        else:
            return "AnswerFormatterFilter validates mathematical answer formatting"

    def _validate_dataframe(self, dataframe: pd.DataFrame):
        """Log errors for missing required / conflicting columns (no raise)."""
        required_keys = [self.input_key]
        forbidden_keys = []
        missing = [k for k in required_keys if k not in dataframe.columns]
        conflict = [k for k in forbidden_keys if k in dataframe.columns]
        if missing:
            self.logger.error(f"Missing required column(s): {missing}")
        if conflict:
            self.logger.error(f"The following column(s) already exist and would be overwritten: {conflict}")
        missing_keys = [key for key in required_keys if key not in dataframe.columns]
        if missing_keys:
            self.logger.error(f"The following required columns are missing from the dataframe: {missing_keys}")

    def run(
        self,
        storage: DataFlowStorage,
        input_key: str = "generated_cot",
    ) -> list:
        '''
        Execute the answer format filter process
        '''
        self.input_key = input_key
        dataframe = storage.read("dataframe")
        self._validate_dataframe(dataframe)
        # Build a *positional* boolean mask. The original indexed the mask
        # with the iterrows() label (`indexes[i]`), which breaks or mislabels
        # rows whenever the dataframe index is not the default 0..n-1.
        keep = np.zeros(len(dataframe), dtype=bool)
        for pos, (_, item) in enumerate(dataframe.iterrows()):
            answer = item[self.input_key]
            if ReasoningAnswerFormatterFilter.is_valid_answer(answer):
                keep[pos] = True
        dataframe = dataframe[keep]
        output_file = storage.write(dataframe)
        self.logger.info(f"Results saved to {output_file}")
        return [self.input_key,]
\ No newline at end of file
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow.utils.reasoning.AnswerExtraction import StringCleaner, UnitTextManager, AnswerExtractor
from dataflow.utils.storage import DataFlowStorage
from dataflow import get_logger
from dataflow.core import OperatorABC
from typing import Literal
from math_verify import parse, verify
import pandas as pd
@OPERATOR_REGISTRY.register()
class ReasoningAnswerGroundTruthFilter(OperatorABC):
    """Keep only rows whose extracted answer matches the ground truth.

    Comparison is either string-exact ("exact") or symbolic math
    verification via math_verify ("math_verify", default).
    """

    def __init__(self,
                 compare_method: Literal["math_verify", "exact"] = "math_verify"):
        # Dispatch table; raises KeyError early for an unknown method name.
        name2compare = {
            'exact': self.exact_compare,
            'math_verify': self.math_verify_compare
        }
        self.compare = name2compare[compare_method]
        unit_manager = UnitTextManager()
        string_cleaner = StringCleaner(unit_manager)
        self.answer_extractor = AnswerExtractor(string_cleaner)
        self.logger = get_logger()

    def exact_compare(self, answer, ground_truth):
        """String equality after str() coercion of both sides."""
        return str(answer) == str(ground_truth)

    def math_verify_compare(self, answer, ground_truth):
        """Symbolic comparison via math_verify; False on any parse failure.

        First tries the str()-coerced forms, then the raw values; narrowed
        from bare `except:` so KeyboardInterrupt/SystemExit propagate.
        """
        try:
            return verify(parse(str(ground_truth)), parse(str(answer)))
        except Exception:
            try:
                return verify(parse(ground_truth), parse(answer))
            except Exception:
                return False

    @staticmethod
    def get_desc(lang: str = "zh"):
        if lang == "zh":
            return (
                "该算子用于对比预测答案与标准答案的匹配度,支持精确匹配和数学验证两种方式。\n\n"
                "输入参数:\n"
                "- input_test_answer_key:预测答案字段名\n"
                "- input_gt_answer_key:标准答案字段名\n"
                "- compare_method:比较方法(exact/math_verify)\n\n"
                "输出参数:\n"
                "- 匹配成功返回1,否则返回0"
            )
        elif lang == "en":
            return (
                "This operator compares predicted answers against ground truth using exact or mathematical verification.\n\n"
                "Input Parameters:\n"
                "- test_answer_key: Predicted answer field\n"
                "- gt_answer_key: Ground truth field\n"
                "- compare_method: Comparison method (exact/math_verify)\n\n"
                "Output Parameters:\n"
                "- Returns 1 for matches, 0 otherwise"
            )
        else:
            return "AnswerGroundTruthFilter performs answer validation"

    def run(
        self,
        storage: DataFlowStorage,
        input_test_answer_key: str = "generated_cot",
        input_gt_answer_key: str = "golden_answer"
    ) -> list:
        """Filter the stored dataframe, writing back only matching rows."""
        self.test_answer_key = input_test_answer_key
        self.gt_answer_key = input_gt_answer_key
        dataframe = storage.read("dataframe")
        answers = dataframe[self.test_answer_key]
        ground_truths = dataframe[self.gt_answer_key]
        kept_rows = []
        # .iloc everywhere: the original used label-based `answers[i]`, which
        # fails (or matches the wrong row) when the index is not 0..n-1.
        for pos in range(len(dataframe)):
            final_answer = self.answer_extractor.extract_answer(answers.iloc[pos], None)
            if self.compare(final_answer, ground_truths.iloc[pos]):
                kept_rows.append(dataframe.iloc[pos])
        # Pass columns explicitly so an empty result keeps the schema instead
        # of producing a column-less DataFrame.
        output = pd.DataFrame(kept_rows, columns=dataframe.columns)
        output_file = storage.write(output)
        self.logger.info(f"Filtered data saved to {output_file}")
        return [self.test_answer_key, self.gt_answer_key]
\ No newline at end of file
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow import get_logger
from dataflow.core import OperatorABC
from dataflow.utils.storage import DataFlowStorage
from dataflow.core import LLMServingABC
from dataflow.prompts.model_evaluation.general import AnswerJudgePromptQuestion, AnswerJudgePrompt
from dataflow.core.prompt import prompt_restrict, DIYPromptABC
import re
import pandas as pd
import numpy as np
from typing import Union
@prompt_restrict(
    AnswerJudgePromptQuestion,
    AnswerJudgePrompt,
)
@OPERATOR_REGISTRY.register()
class ReasoningAnswerModelJudgeFilter(OperatorABC):
    """Judge answer correctness with an LLM by comparing each answer against
    a reference answer, writing back either all rows (with an
    ``answer_match_result`` column) or only the rows judged correct.
    """

    def __init__(self,
                 system_prompt: str = "You are a helpful assistant specialized in evaluating answer correctness.",
                 llm_serving: LLMServingABC = None,
                 prompt_template: Union[AnswerJudgePromptQuestion, AnswerJudgePrompt, DIYPromptABC] = AnswerJudgePromptQuestion,
                 keep_all_samples: bool = False,  # when True, keep rows that fail or are skipped
                 ):
        self.logger = get_logger()
        # NOTE(review): the default for prompt_template is the *class*
        # AnswerJudgePromptQuestion, not an instance, and only an explicit
        # None is replaced below — build_prompt is later invoked on
        # self.prompt_template directly. Confirm the prompt classes support
        # class-level build_prompt, or instantiate the default.
        if prompt_template is None:
            prompt_template = AnswerJudgePrompt()
        self.prompt_template = prompt_template
        self.system_prompt = system_prompt
        self.llm_serving = llm_serving
        self.empty_responses_count = 0  # counts empty LLM responses; reported in run()
        self.keep_all_samples = keep_all_samples

    @staticmethod
    def get_desc(lang: str = "zh"):
        if lang == "zh":
            return (
                "该算子用于对答案进行正确性评判,通过比较当前答案与参考答案的语义一致性,判断答案是否正确。"
                "调用大语言模型进行语义理解和判断,最终返回每个答案是否正确的二分类结果。\n"
                "输入参数:\n"
                "- system_prompt:系统提示词,用于定义模型行为\n"
                "- llm_serving:LLM服务对象,需实现LLMServingABC接口\n"
                "- prompt_template:提示模板对象,用于构建评判提示词\n"
                "- keep_all_samples:是否保留所有样本,默认为False(仅保留正确答案)\n"
                "- question_key:问题字段名,默认为'question'\n"
                "- answer_key:当前答案字段名,默认为'answer'\n"
                "- reference_key:参考答案字段名,默认为'reference_answer'\n"
                "输出参数:\n"
                "- DataFrame,包含原始数据和判断结果(answer_match_result字段)\n"
                "- 如果keep_all_samples为False,则仅保留判断结果为True的行\n"
                "- 返回包含输入字段名的列表,用于后续算子引用"
            )
        elif lang == "en":
            return (
                "This operator evaluates the correctness of answers by comparing the semantic consistency between "
                "the current answer and the reference answer. It uses a large language model for semantic understanding "
                "and judgment, ultimately returning a binary classification result for each answer.\n"
                "Input Parameters:\n"
                "- system_prompt: System prompt to define model behavior\n"
                "- llm_serving: LLM serving object implementing LLMServingABC interface\n"
                "- prompt_template: Prompt template object for constructing evaluation prompts\n"
                "- keep_all_samples: Whether to keep all samples, default is False (only keep correct answers)\n"
                "- question_key: Field name for questions, default is 'question'\n"
                "- answer_key: Field name for current answers, default is 'answer'\n"
                "- reference_key: Field name for reference answers, default is 'reference_answer'\n\n"
                "Output Parameters:\n"
                "- DataFrame containing original data and judgment results (answer_match_result field)\n"
                "- If keep_all_samples is False, only rows with True judgment results are retained\n"
                "- List containing input field names for subsequent operator reference"
            )
        else:
            return (
                "AnswerJudge evaluates answer correctness by comparing semantic consistency with reference answers using LLM."
            )

    def ResolveResponse(self, response):
        """Parse an LLM judgement response into a bool.

        Looks for a '"judgement_result": true/false' field; if absent, falls
        back to scanning the whole response for the word "true". Empty or
        None responses count toward empty_responses_count and return False.
        """
        if response is None or (isinstance(response, str) and response.strip() == ''):
            self.empty_responses_count += 1
            return False
        try:
            pattern = re.compile(r'"judgement_result"\s*:\s*(true|false)', re.IGNORECASE)
            match = pattern.search(response)
            if match:
                result_value = match.group(1).lower()
            else:
                # Fallback: any occurrence of "true" counts as a positive verdict.
                result_value = "true" if "true" in response.lower() else "false"
            return result_value == "true"
        except Exception as e:
            self.logger.error(f"Response format error: {response}. Error: {e}")
            return False

    def run(self, storage: DataFlowStorage, input_question_key: str = "question", input_answer_key: str = "answer", input_reference_key: str = "reference_answer") -> list:
        """Judge every row with a non-empty reference answer and write results.

        Rows with a missing/empty reference are skipped (result False).
        Returns the list of input column names plus 'answer_match_result'.
        """
        self.question_key = input_question_key
        self.answer_key = input_answer_key
        self.reference_key = input_reference_key
        dataframe = storage.read("dataframe")
        # Verify required columns exist before doing any work.
        required_columns = [input_question_key, input_answer_key, input_reference_key]
        for column in required_columns:
            if column not in dataframe.columns:
                self.logger.error(f"Required column '{column}' not found in dataframe")
                return required_columns
        # Split rows by whether a reference answer is present.
        empty_reference_mask = dataframe[input_reference_key].isna() | (dataframe[input_reference_key] == '')
        skipped_rows = dataframe[empty_reference_mask]
        valid_rows = dataframe[~empty_reference_mask]
        skipped_count = len(skipped_rows)
        # Default every row to False; valid rows are overwritten below.
        dataframe['answer_match_result'] = False
        if len(valid_rows) == 0:
            self.logger.warning("No valid samples with reference answers found. All samples skipped.")
            if self.keep_all_samples:
                output_file = storage.write(dataframe)  # keep all rows, all False
            else:
                output_file = storage.write(pd.DataFrame(columns=dataframe.columns))  # keep nothing
            self.logger.info(f"Dataframe saved to {output_file}. Skipped {skipped_count} samples due to missing reference answers.")
            return required_columns + ['answer_match_result']
        # Build prompts and query the LLM only for rows with a reference.
        inputs = [self.prompt_template.build_prompt(
            question=row[input_question_key],
            answer=row[input_answer_key],
            reference_answer=row[input_reference_key]
        ) for _, row in valid_rows.iterrows()]
        responses = self.llm_serving.generate_from_input(user_inputs=inputs, system_prompt=self.system_prompt)
        results = [self.ResolveResponse(response) for response in responses]
        # Write the verdicts back by label in one vectorized assignment
        # (the original also built an unused result_mask array).
        dataframe.loc[valid_rows.index, 'answer_match_result'] = results
        if self.keep_all_samples:
            # Keep everything, including mismatches and skipped rows.
            final_dataframe = dataframe
        else:
            # Keep only rows the LLM judged correct.
            final_dataframe = dataframe[dataframe['answer_match_result'] == True]
        output_file = storage.write(final_dataframe)
        # Summary statistics for the log.
        total_samples = len(dataframe)
        valid_samples = len(valid_rows)
        matched_samples = sum(results)
        accuracy = matched_samples / valid_samples if valid_samples > 0 else 0
        self.logger.info(f"Processed answers saved to {output_file}.")
        self.logger.info(f"Total samples: {total_samples}, Valid samples: {valid_samples}, Skipped samples: {skipped_count}")
        self.logger.info(f"Matched answers: {matched_samples}, Accuracy: {accuracy:.2%}")
        self.logger.info(f"Output samples: {len(final_dataframe)}")
        # Report and reset the empty-response counter.
        if self.empty_responses_count > 0:
            self.logger.error(f"Found {self.empty_responses_count} empty responses during evaluation.")
            self.empty_responses_count = 0
        return required_columns + ['answer_match_result']
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment