"llama/git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "e9e5f61c45e13f9b87be985ae735de8c217e9915"
Commit 97e8278b authored by zzg_666's avatar zzg_666
Browse files

适配后端vllm

parents
Pipeline #3071 canceled with stages
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow import get_logger
from dataflow.utils.storage import DataFlowStorage
from dataflow.core import OperatorABC
import numpy as np
import pandas as pd
from tqdm import tqdm
import re
@OPERATOR_REGISTRY.register()
class ReasoningAnswerNgramFilter(OperatorABC):
    """Filter rows whose question + answer text is too repetitive.

    For each row the question and answer strings are concatenated,
    lower-cased and stripped of punctuation, then split into n-grams.
    The score is the ratio of unique n-grams to total n-grams (1.0 means
    no repetition at all); rows whose score falls outside
    [min_score, max_score] are dropped.
    """

    def __init__(self,
                 min_score: float = 0.1,
                 max_score: float = 1.0,
                 ngrams: int = 5):
        # Accepted range for the uniqueness ratio, and the n-gram window size.
        self.min_score = min_score
        self.max_score = max_score
        self.ngrams = ngrams
        self.logger = get_logger()

    @staticmethod
    def get_desc(lang: str = "zh"):
        if lang == "zh":
            return (
                "该算子基于n-gram重复率过滤答案,检测回答中的重复模式。\n\n"
                "输入参数:\n"
                "- min_score:最小可接受分数\n"
                "- max_score:最大可接受分数\n"
                "- ngrams:n-gram大小\n\n"
                "输出参数:\n"
                "- 分数在范围内返回1,否则返回0"
            )
        elif lang == "en":
            return (
                "This filter detects repetitive patterns using n-gram repetition scores.\n\n"
                "Input Parameters:\n"
                "- min_score: Minimum acceptable score\n"
                "- max_score: Maximum acceptable score\n"
                "- ngrams: Size of n-grams\n\n"
                "Output Parameters:\n"
                "- Returns 1 if score is within range, 0 otherwise"
            )
        else:
            return "AnswerNgramFilter detects answer repetition"

    def _validate_dataframe(self, dataframe: pd.DataFrame):
        """Log an error for every required column missing from the dataframe.

        Bug fix: the original performed the same missing-column scan twice
        and logged it under two different messages; one pass suffices.
        """
        required_keys = [self.question_key, self.answer_key]
        missing = [k for k in required_keys if k not in dataframe.columns]
        if missing:
            self.logger.error(f"Missing required column(s): {missing}")

    def _uniqueness_score(self, text: str) -> float:
        """Return the unique/total n-gram ratio of *text* (0.0 when no n-grams)."""
        content = re.sub(r'[^\w\s]', '', text.lower())
        words = content.split()
        grams = [' '.join(words[i:i + self.ngrams])
                 for i in range(len(words) - (self.ngrams - 1))]
        return len(set(grams)) / len(grams) if grams else 0.0

    def run(
        self,
        storage: DataFlowStorage,
        input_question_key: str = "instruction",
        input_answer_key: str = "generated_cot"
    ) -> list:
        """Score every row and keep only rows whose score is within range."""
        self.question_key = input_question_key
        self.answer_key = input_answer_key
        dataframe = storage.read("dataframe")
        self.logger.info(f"Found {len(dataframe)} rows in the dataframe, DataFrame columns: {dataframe.columns}")
        # Bug fix: the validation helper existed but was never invoked.
        self._validate_dataframe(dataframe)
        missing_answer_logged = False
        scores = []
        for _, sample in dataframe.iterrows():
            try:
                answer = sample.get(self.question_key, "") + sample.get(self.answer_key, "")
            except Exception:
                # A non-string answer cell (e.g. NaN) makes the concatenation
                # fail; fall back to scoring the question alone.
                if not missing_answer_logged:
                    self.logger.info(f"*** Only question is available ***")
                    missing_answer_logged = True
                answer = sample.get(self.question_key, "")
            scores.append(self._uniqueness_score(answer))
        indexes = np.array([self.min_score <= s <= self.max_score for s in scores])
        dataframe = dataframe[indexes]
        self.logger.info(f"Filtered down to {len(dataframe)} rows with repetition score in [{self.min_score}, {self.max_score}]")
        output_file = storage.write(dataframe)
        self.logger.info(f"Results saved to {output_file}")
        return [input_question_key, input_answer_key]
\ No newline at end of file
from dataflow import get_logger
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow.utils.reasoning.AnswerExtraction import StringCleaner, UnitTextManager, AnswerExtractor
from dataflow.core import OperatorABC
from dataflow.utils.storage import DataFlowStorage
import pandas as pd
@OPERATOR_REGISTRY.register()
class ReasoningAnswerPipelineRootFilter(OperatorABC):
    """Root node of the answer pipeline.

    Splits the input rows into a with-ground-truth branch and a
    without-ground-truth branch, optionally recovering a ground truth
    from the model answer when the GT cell is empty.
    """

    def __init__(self):
        self.logger = get_logger()

    @staticmethod
    def get_desc(lang: str = "zh"):
        if lang == "zh":
            return (
                "答案处理流程根节点,负责将输入数据根据有无真实标签GT分发到不同处理分支。\n\n"
                "输入参数:\n"
                "- input_file:输入文件路径\n"
                "- output_dir:输出目录路径\n"
                "- branch_config:分支配置参数\n"
                "- parallel_workers:并行工作线程数\n\n"
                "输出参数:\n"
                "- 多个输出文件路径(根据分支配置生成)"
            )
        elif lang == "en":
            return (
                "Root node of answer processing pipeline, distributes input data to different processing branches.\n\n"
                "Input Parameters:\n"
                "- input_file: Input file path\n"
                "- output_dir: Output directory path\n"
                "- branch_config: Branch configuration parameters\n"
                "- parallel_workers: Number of parallel workers\n\n"
                "Output Parameters:\n"
                "- Multiple output file paths (generated based on branch config)"
            )
        else:
            return "AnswerPipelineRoot routes data to different processing branches."

    def run(self, storage: DataFlowStorage, input_answer_key: str = "output", input_gt_key: str = "golden_answer"):
        """Distribute rows by ground-truth availability.

        Rows with a usable GT (present, or recoverable from the answer text)
        are written first; rows without one are written second with the GT
        column cleared.
        """
        self.input_answer_key = input_answer_key
        self.input_gt_key = input_gt_key
        df = storage.read("dataframe")
        if not self.input_gt_key or self.input_gt_key not in df.columns:
            self.logger.warning("No valid gt key in input file, copy input file to output file without gt")
            # Bug fix: the warning promises a copy, but the input dataframe
            # was never written back; persist it before bailing out.
            storage.write(df)
            return
        # Initialise the answer extractor used to recover a GT from the answer.
        if self.input_answer_key in df.columns:
            unit_text_manager = UnitTextManager()
            string_cleaner = StringCleaner(unit_text_manager)
            answer_extractor = AnswerExtractor(string_cleaner)

            def extract_gt(answer, gt):
                """Return the existing gt, a gt extracted from the answer, or None."""
                try:
                    if gt != "" and not pd.isna(gt):
                        return gt
                    if pd.isna(answer) or answer == "":
                        return None
                    return answer_extractor.extract_answer(answer, None, True)
                except Exception as e:
                    self.logger.error(f"Error in extract_gt: {e}", exc_info=True)
                    return None

            # Use apply to walk the DataFrame without explicit index bookkeeping.
            df[self.input_gt_key] = df.apply(
                lambda row: extract_gt(row[self.input_answer_key], row[self.input_gt_key]),
                axis=1,
            )
        # Split rows with and without a ground truth.
        df_with_gt = df[(df[self.input_gt_key].notna()) & (df[self.input_gt_key] != "")]
        df_without_gt = df[(df[self.input_gt_key].isna()) | (df[self.input_gt_key] == "")].copy()
        df_without_gt[self.input_gt_key] = None
        # Write out each non-empty branch.
        if len(df_with_gt) > 0:
            output_file_gt = storage.write(df_with_gt)
            self.logger.info(f"output {df_with_gt.shape[0]} rows with gt to {output_file_gt}")
        if len(df_without_gt) > 0:
            output_file_without_gt = storage.write(df_without_gt)
            self.logger.info(f"output {df_without_gt.shape[0]} rows without gt to {output_file_without_gt}")
\ No newline at end of file
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow import get_logger
from dataflow.utils.storage import DataFlowStorage
from dataflow.core import OperatorABC
from transformers import AutoTokenizer
from tqdm import tqdm
import pandas as pd
@OPERATOR_REGISTRY.register()
class ReasoningAnswerTokenLengthFilter(OperatorABC):
    """Drop rows whose answer exceeds a maximum token count.

    Token counting uses a HuggingFace tokenizer; empty or un-tokenisable
    entries are counted separately and always filtered out.
    """

    def __init__(self,
                 max_answer_token_length: int = 8192,
                 tokenizer_dir: str = "Qwen/Qwen2.5-0.5B-Instruct"):
        self.max_answer_token_length = max_answer_token_length
        self.tokenizer_dir = tokenizer_dir
        self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_dir)
        self.logger = get_logger()
        self.empty_count = 0  # counts empty or invalid entries per run

    @staticmethod
    def get_desc(lang: str = "zh"):
        if lang == "zh":
            return (
                "该算子根据token数量过滤过长的答案。\n\n"
                "输入参数:\n"
                "- max_answer_token_length:最大token数\n"
                "- tokenizer_dir:分词器路径\n"
                "- read_min/max_score:分数范围\n\n"
                "输出参数:\n"
                "- 长度合规返回1,否则返回0"
            )
        elif lang == "en":
            return (
                "Filters answers exceeding specified token length limit.\n\n"
                "Input Parameters:\n"
                "- max_answer_token_length: Maximum allowed tokens\n"
                "- tokenizer_dir: Tokenizer directory\n"
                "- read_min/max_score: Score range\n\n"
                "Output Parameters:\n"
                "- Returns 1 if within limit, 0 otherwise"
            )
        else:
            return "AnswerTokenLengthFilter enforces answer length constraints"

    def _validate_dataframe(self, dataframe: pd.DataFrame):
        """Log an error for every required column missing from the dataframe.

        Bug fix: the original scanned for missing columns twice with two
        different error messages; a single pass is sufficient.
        """
        required_keys = [self.input_key]
        missing = [k for k in required_keys if k not in dataframe.columns]
        if missing:
            self.logger.error(f"Missing required column(s): {missing}")

    def _token_count(self, input_string) -> int:
        """Return the token count of *input_string*.

        Empty/None or un-encodable inputs return max_answer_token_length + 1
        so they are guaranteed to be filtered out, and bump empty_count.
        """
        if input_string is None or (isinstance(input_string, str) and input_string.strip() == ''):
            self.empty_count += 1
            return self.max_answer_token_length + 1
        try:
            return len(self.tokenizer.encode(input_string, add_special_tokens=False))
        except Exception as e:
            self.logger.error(f"Token encoding error: {e}")
            self.empty_count += 1
            return self.max_answer_token_length + 1

    def run(
        self,
        storage: DataFlowStorage,
        input_key: str = "generated_cot"
    ) -> list:
        """Filter the dataframe to rows whose text fits the token budget."""
        dataframe = storage.read("dataframe")
        self.input_key = input_key
        self.logger.info(f"Found {len(dataframe)} rows in the dataframe")
        self._validate_dataframe(dataframe)
        # Boolean-mask filtering keeps column headers and dtypes even when
        # every row is rejected (rebuilding a DataFrame from a list of Series,
        # as before, loses both for an empty result).
        mask = [
            self._token_count(text) <= self.max_answer_token_length
            for text in tqdm(dataframe[self.input_key], desc="Checking token lengths")
        ]
        dataframe = dataframe[mask]
        output_file = storage.write(dataframe)
        self.logger.info(f"Saved {len(dataframe)} filtered rows to {output_file}")
        # Report empty/invalid entries and reset the counter for the next run.
        if self.empty_count > 0:
            self.logger.warning(f"Found {self.empty_count} empty or invalid entries during filtering")
            self.empty_count = 0
        return [self.input_key]
\ No newline at end of file
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow import get_logger
from dataflow.core import OperatorABC
from dataflow.utils.storage import DataFlowStorage
from dataflow.core import LLMServingABC
from dataflow.core.prompt import DIYPromptABC
from dataflow.prompts.reasoning.math import MathQuestionFilterPrompt
from dataflow.prompts.reasoning.general import GeneralQuestionFilterPrompt
from dataflow.prompts.reasoning.diy import DiyQuestionFilterPrompt
from dataflow.core.prompt import prompt_restrict
from typing import Union
import re
@prompt_restrict(
    MathQuestionFilterPrompt,
    GeneralQuestionFilterPrompt,
    DiyQuestionFilterPrompt
)
@OPERATOR_REGISTRY.register()
class ReasoningQuestionFilter(OperatorABC):
    """LLM-based question correctness filter.

    Builds a check prompt per question, asks the LLM, parses the
    judgement from the response and keeps only rows judged True.
    """

    def __init__(self,
                 system_prompt: str = "You are a helpful assistant.",
                 llm_serving: LLMServingABC = None,
                 prompt_template: Union[MathQuestionFilterPrompt, GeneralQuestionFilterPrompt, DiyQuestionFilterPrompt, DIYPromptABC] = MathQuestionFilterPrompt
                 ):
        self.logger = get_logger()
        if prompt_template is None:
            prompt_template = MathQuestionFilterPrompt()
        elif isinstance(prompt_template, type):
            # Bug fix: the declared default is the CLASS itself (not an
            # instance), so the None-guard above never fired for the default
            # and build_prompt was later called unbound. Instantiate any
            # class passed here (including the default).
            prompt_template = prompt_template()
        self.prompt_template = prompt_template
        self.system_prompt = system_prompt
        self.llm_serving = llm_serving
        self.empty_responses_count = 0  # counts empty LLM responses per run

    @staticmethod
    def get_desc(lang: str = "zh"):
        if lang == "zh":
            return (
                "该算子用于对问题进行正确性检查,包括格式是否规范、语义是否合理、条件是否矛盾以及是否具备充分信息可解。"
                "调用大语言模型依次执行四阶段判断,最终返回每个问题是否合格的二分类结果(保留合格样本)。\n"
                "输入参数:\n"
                "- system_prompt:系统提示词,用于定义模型行为\n"
                "- llm_serving:LLM服务对象,需实现LLMServingABC接口\n"
                "- prompt_template:提示模板对象,用于构建检查提示词\n"
                "- input_key:输入问题字段名,默认为'math_problem'\n"
                "输出参数:\n"
                "- 过滤后的DataFrame,仅保留判断结果为True的行\n"
                "- 返回包含输入字段名的列表,用于后续算子引用"
            )
        elif lang == "en":
            return (
                "This operator checks the correctness of questions, including formatting, semantic validity, logical consistency, \n"
                "and whether the problem is solvable. It performs a four-stage evaluation using a large language model and retains qualified samples.\n"
                "Input Parameters:\n"
                "- system_prompt: System prompt to define model behavior\n"
                "- llm_serving: LLM serving object implementing LLMServingABC interface\n"
                "- prompt_template: Prompt template object for constructing check prompts\n"
                "- input_key: Field name for input questions, default is 'math_problem'\n\n"
                "Output Parameters:\n"
                "- Filtered DataFrame containing only rows with True judgment results\n"
                "- List containing input field name for subsequent operator reference"
            )
        else:
            return (
                "QuestionFilter performs correctness checking on questions using a multi-stage LLM evaluation and retains qualified samples."
            )

    def ResolveResponse(self, response):
        """Parse an LLM response into a boolean judgement.

        Prefers an explicit `"judgement_test": true/false` field; falls back
        to a plain substring scan for "true". Empty responses count as False.
        """
        if response is None or (isinstance(response, str) and response.strip() == ''):
            self.empty_responses_count += 1
            return False
        try:
            pattern = re.compile(r'"judgement_test"\s*:\s*(true|false)', re.IGNORECASE)
            match = pattern.search(response)
            if match:
                test_value = match.group(1).lower()
            else:
                # No structured field; fall back to a loose textual check.
                test_value = "true" if "true" in response.lower() else "false"
            return test_value == "true"
        except Exception as e:
            self.logger.error(f"Response format error for problem: {response}. Error: {e}")
            return False

    def run(self, storage: DataFlowStorage, input_key: str = "math_problem") -> list:
        """Ask the LLM about every question and keep rows judged True."""
        self.input_key = input_key
        dataframe = storage.read("dataframe")
        questions = dataframe[input_key]
        inputs = [self.prompt_template.build_prompt(question) for question in questions]
        responses = self.llm_serving.generate_from_input(user_inputs=inputs, system_prompt=self.system_prompt)
        results = [self.ResolveResponse(response) for response in responses]
        # Keep only the rows whose judgement resolved to True.
        dataframe = dataframe[results]
        output_file = storage.write(dataframe)
        self.logger.info(f"Filtered questions saved to {output_file}")
        # Report empty responses and reset the counter for the next run.
        if self.empty_responses_count > 0:
            self.logger.error(f"Found {self.empty_responses_count} empty responses during filtering.")
            self.empty_responses_count = 0
        return [input_key,]
\ No newline at end of file
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow import get_logger
from dataflow.core import OperatorABC
from dataflow.utils.storage import DataFlowStorage
from dataflow.utils.reasoning.AnswerExtraction import StringCleaner, UnitTextManager, AnswerExtractor
from word2number import w2n
from tqdm import tqdm
import pandas as pd
import logging
import re
# The main class to manage the entire extraction process
@OPERATOR_REGISTRY.register()
class ReasoningAnswerExtractionQwenMathEvalGenerator(OperatorABC):
    """Extract a normalised answer expression from every row of a dataset."""

    def __init__(self, dataset_name: str = None):
        """Set up the logger and the answer-extraction helper chain."""
        self.logger = get_logger()
        self.data_name = dataset_name
        # Extraction pipeline: unit handling feeds string cleaning,
        # which feeds the answer extractor itself.
        self.answer_extractor = AnswerExtractor(StringCleaner(UnitTextManager()))

    @staticmethod
    def get_desc(lang: str = "zh"):
        if lang == "zh":
            return (
                "该算子用于从数学问题回答中提取规范化答案表达式,进行字符串清洗、单位处理和格式标准化。\n\n"
                "输入参数:\n"
                "- input_key:输入数据字段名\n"
                "- answer_key:原始答案字段名\n"
                "- output_key:处理后的答案字段名\n"
                "- unit_texts:需要过滤的单位文本列表\n\n"
                "输出参数:\n"
                "- output_key:标准化后的数学表达式字段"
            )
        elif lang == "en":
            return (
                "This operator extracts and normalizes mathematical expressions from answers, "
                "performing string cleaning, unit processing and format standardization.\n\n"
                "Input Parameters:\n"
                "- input_key: Input data field name\n"
                "- answer_key: Raw answer field name\n"
                "- output_key: Processed answer field name\n"
                "- unit_texts: List of unit texts to filter\n\n"
                "Output Parameters:\n"
                "- output_key: Standardized mathematical expression field"
            )
        else:
            return "AnswerExtraction_QwenMathEval performs mathematical answer normalization and standardization."

    def run(self, storage: DataFlowStorage, input_key: str = "pseudo_correct_solution_example", output_key: str = "extraction"):
        """Read the dataframe, extract an answer per row and write the result."""
        self.input_key, self.output_key = input_key, output_key
        frame = storage.read("dataframe")
        if self.input_key not in frame.columns.to_list():
            raise ValueError(f"input_key: {self.input_key} not found in dataframe columns.")
        self.logger.info(f"Found {len(frame)} rows.")
        results = []
        for response in tqdm(frame[self.input_key], desc='Processing'):
            results.append(self.answer_extractor.extract_answer(response, self.data_name))
        frame[self.output_key] = results
        output_file = storage.write(frame)
        self.logger.info(f"Extracted answers saved to {output_file}")
        return [output_key]
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow import get_logger
from dataflow.utils.storage import DataFlowStorage
from dataflow.core import OperatorABC
from dataflow.core import LLMServingABC
from dataflow.prompts.reasoning.math import MathAnswerGeneratorPrompt
from dataflow.prompts.reasoning.general import GeneralAnswerGeneratorPrompt
from dataflow.prompts.reasoning.diy import DiyAnswerGeneratorPrompt
from dataflow.core.prompt import prompt_restrict, DIYPromptABC
import pandas as pd
from typing import Union
@prompt_restrict(
    MathAnswerGeneratorPrompt,
    GeneralAnswerGeneratorPrompt,
    DiyAnswerGeneratorPrompt
)
@OPERATOR_REGISTRY.register()
class ReasoningAnswerGenerator(OperatorABC):
    '''
    Answer Generator is a class that generates answers for given questions.
    '''

    def __init__(self,
                 llm_serving: LLMServingABC,
                 prompt_template: Union[MathAnswerGeneratorPrompt, GeneralAnswerGeneratorPrompt, DiyAnswerGeneratorPrompt, DIYPromptABC] = MathAnswerGeneratorPrompt
                 ):
        self.logger = get_logger()
        if prompt_template is None:
            prompt_template = MathAnswerGeneratorPrompt()
        elif isinstance(prompt_template, type):
            # Bug fix: the declared default is the CLASS itself (not an
            # instance), so the None-guard above never fired for the default
            # and build_prompt was later called unbound. Instantiate any
            # class passed here (including the default).
            prompt_template = prompt_template()
        self.prompts = prompt_template
        self.llm_serving = llm_serving

    @staticmethod
    def get_desc(lang: str = "zh"):
        if lang == "zh":
            return (
                "该算子用于为给定问题生成答案,调用大语言模型进行推理。\n"
                "输入参数:\n"
                "- llm_serving:LLM服务实例,用于生成答案\n"
                "- prompt_template:提示模板对象,用于构建生成提示词\n"
                "输出参数:\n"
                "- output_key:生成的答案字段,默认'generated_cot'"
            )
        elif lang == "en":
            return (
                "This operator generates answers for given questions using LLMs for reasoning. \n"
                "Input Parameters:\n"
                "- llm_serving: LLM serving instance for answer generation\n"
                "- prompt_template: Prompt template object for constructing generation prompts\n"
                "Output Parameters:\n"
                "- output_key: Generated answer field, default 'generated_cot'"
            )
        else:
            return "AnswerGenerator produces answers for questions using large language models."

    def _validate_dataframe(self, dataframe: pd.DataFrame):
        """Raise when a required column is missing or an output column already exists."""
        required_keys = [self.input_key]
        forbidden_keys = [self.output_key]
        missing = [k for k in required_keys if k not in dataframe.columns]
        conflict = [k for k in forbidden_keys if k in dataframe.columns]
        if missing:
            raise ValueError(f"Missing required column(s): {missing}")
        if conflict:
            raise ValueError(f"The following column(s) already exist and would be overwritten: {conflict}")

    def _reformat_prompt(self, dataframe):
        """
        Reformat the prompts in the dataframe to generate questions.
        """
        questions = dataframe[self.input_key].tolist()
        inputs = [self.prompts.build_prompt(question) for question in questions]
        return inputs

    def run(
        self,
        storage: DataFlowStorage,
        input_key: str = "instruction",
        output_key: str = "generated_cot"
    ):
        '''
        Runs the answer generation process, reading from the input file and saving results to output.
        '''
        self.input_key, self.output_key = input_key, output_key
        dataframe = storage.read("dataframe")
        self._validate_dataframe(dataframe)
        formatted_prompts = self._reformat_prompt(dataframe)
        answers = self.llm_serving.generate_from_input(formatted_prompts)
        dataframe[self.output_key] = answers
        output_file = storage.write(dataframe)
        self.logger.info(f"Results saved to {output_file}")
        return [output_key]
\ No newline at end of file
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow.utils.storage import DataFlowStorage
from dataflow.core import OperatorABC
from dataflow import get_logger
import pandas as pd
@OPERATOR_REGISTRY.register()
class ReasoningPretrainFormatConvertGenerator(OperatorABC):
    """Convert SFT-style (question, answer) rows into pretraining text rows."""

    def __init__(self):
        self.logger = get_logger()

    def run(self,
            storage: DataFlowStorage,
            input_read_key_question: str = "question",
            input_read_key_answer: str = "answer",
            output_key: str = "text"
            ):
        """Concatenate question and answer into one text field per row."""
        self.read_key_question = input_read_key_question
        self.read_key_answer = input_read_key_answer
        self.output_key = output_key
        dataframe = storage.read("dataframe")
        # Turn NaN cells into None so the fallbacks below apply uniformly.
        output_rows = dataframe.where(pd.notnull(dataframe), None).to_dict(orient="records")
        output_1 = []
        for row in output_rows:
            cur_q = row.get(self.read_key_question) if row.get(self.read_key_question) is not None else ""
            cur_a = row.get(self.read_key_answer) if row.get(self.read_key_answer) is not None else ""
            # Bug fix: the configured output_key was ignored and the literal
            # "text" was always used, silently breaking callers that pass a
            # custom output field name.
            output_1.append({
                self.output_key: cur_q + "\n" + cur_a,
            })
        output_file = storage.write(output_1)
        self.logger.info(f"SFT to PT convertion results saved to {output_file}")
        return [input_read_key_question, input_read_key_answer, output_key]

    @staticmethod
    def get_desc(lang: str = "zh"):
        if lang == "zh":
            return (
                "该算子用于将SFT格式数据转换为预训练格式。\n\n"
                "输入参数:\n"
                "- read_key_question:问题字段名\n"
                "- read_key_answer:答案字段名\n"
                "- output_key:输出文本字段名\n\n"
                "输出参数:\n"
                "- output_key:输出文本字段名,包含问题和答案的拼接结果\n"
                "- 输出文件:转换后的预训练格式数据文件路径"
            )
        elif lang == "en":
            return (
                "Converts SFT format data to pretraining format.\n\n"
                "Input Parameters:\n"
                "- read_key_question: Question field name\n"
                "- read_key_answer: Answer field name\n"
                "- output_key: Output text field name\n\n"
                "Output Parameters:\n"
                "- output_key: Output text field name containing concatenated question and answer\n"
                "- Output file: Path to pretraining format data file"
            )
        else:
            return "FormatConvert_SFT_to_Pretrain: SFT to Pretraining format converter"
\ No newline at end of file
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow import get_logger
from dataflow.utils.storage import DataFlowStorage
from dataflow.core import OperatorABC
from dataflow.core import LLMServingABC
from dataflow.utils.reasoning.AnswerExtraction import StringCleaner, UnitTextManager, AnswerExtractor
from dataflow.prompts.reasoning.math import MathAnswerGeneratorPrompt
from dataflow.core.prompt import prompt_restrict
from collections import defaultdict, Counter
import pandas as pd
@prompt_restrict(
    MathAnswerGeneratorPrompt
)
@OPERATOR_REGISTRY.register()
class ReasoningPseudoAnswerGenerator(OperatorABC):
    '''
    Pseudo Answer Generator is a class that generates answers for given questions, then choose the most frequent answer.
    '''

    def __init__(self, llm_serving: LLMServingABC = None, max_times: int = 3):
        self.logger = get_logger()
        self.prompts = MathAnswerGeneratorPrompt()
        self.llm_serving = llm_serving
        self.max_times = max_times  # number of sampling rounds per question

    def check_config(self):
        """Validate a legacy config dict.

        NOTE(review): references self.config, which is never assigned on this
        operator — calling this raises AttributeError. Presumably a leftover
        from the config-driven API; confirm before use.
        """
        required_keys = ["input_file", "output_file", "input_key", "output_key_answer", "output_key_answer_value", "output_key_solutions", "output_key_correct_solution_example", "max_times"]
        missing_keys = [key for key in required_keys if key not in self.config]
        if missing_keys:
            raise ValueError(f"Missing required config keys: {missing_keys}")

    def get_extractor(self):
        """Build the unit-manager -> cleaner -> extractor chain."""
        unit_manager = UnitTextManager()
        string_cleaner = StringCleaner(unit_manager)
        answer_extractor = AnswerExtractor(string_cleaner)
        return answer_extractor

    @staticmethod
    def get_desc(lang: str = "zh"):
        if lang == "zh":
            return (
                "该算子生成多个候选答案并通过统计选择最优解,实现伪答案生成。\n\n"
                "输入参数:\n"
                "- input_file:输入文件路径\n"
                "- output_file:输出文件路径\n"
                "- max_times:最大生成次数\n"
                "- selection_mode:统计选择模式(frequency/consistency)\n\n"
                "输出参数:\n"
                "- final_answer:最终选择答案字段\n"
                "- candidate_answers:候选答案列表字段"
            )
        elif lang == "en":
            return (
                "This operator generates multiple candidate answers and selects the optimal solution "
                "through statistical analysis.\n\n"
                "Input Parameters:\n"
                "- input_file: Input file path\n"
                "- output_file: Output file path\n"
                "- max_times: Maximum generation times\n"
                "- selection_mode: Statistical selection mode (frequency/consistency)\n\n"
                "Output Parameters:\n"
                "- final_answer: Selected answer field\n"
                "- candidate_answers: Candidate answers list field"
            )
        else:
            return "PseudoAnswerGenerator produces pseudo-answers through multi-round generation and selection."

    def _validate_dataframe(self, dataframe: pd.DataFrame):
        """Raise when the input column is missing or any output column already exists."""
        required_keys = [self.input_key]
        forbidden_keys = [
            self.output_key_answer,
            self.output_key_answer_value,
            self.output_key_solutions,
            self.output_key_correct_solution_example,
        ]
        missing = [k for k in required_keys if k not in dataframe.columns]
        conflict = [k for k in forbidden_keys if k in dataframe.columns]
        if missing:
            key_list = dataframe.columns.tolist()
            raise ValueError(
                f"read_key: {missing[0]} not found in the dataframe, "
                f"please check the read_key: {key_list}"
            )
        if conflict:
            key_list = dataframe.columns.tolist()
            raise ValueError(
                f"Found {conflict} in the dataframe, which leads to overwriting the existing column(s), "
                f"please check the output_key: {key_list}"
            )

    def run(
        self,
        storage: DataFlowStorage,
        input_key: str = "instruction",
        output_key_answer: str = "pseudo_answers",
        output_key_answer_value: str = "pseudo_answer_value",
        output_key_solutions: str = "pseudo_solutions",
        output_key_correct_solution_example: str = "pseudo_correct_solution_example",
    ):
        """Sample max_times answers per question and keep the majority answer.

        Per row, stores all extracted answers, the solutions matching the
        majority answer, one example solution, and the majority answer value;
        rows with no usable answer are dropped.
        """
        self.input_key, self.output_key_answer, self.output_key_answer_value = input_key, output_key_answer, output_key_answer_value
        self.output_key_solutions, self.output_key_correct_solution_example = output_key_solutions, output_key_correct_solution_example
        self.extractor = self.get_extractor()
        dataframe = storage.read("dataframe")
        self._validate_dataframe(dataframe)
        input_data_number = dataframe.shape[0]
        user_prompts = dataframe[self.input_key].tolist()
        answer_dict = defaultdict(list)
        solution_dict = defaultdict(list)
        self.logger.info(f"Generating answers for {len(user_prompts)} questions")
        for i in range(self.max_times):
            self.logger.info(f"Generating: {i+1} times")
            solutions = self.llm_serving.generate_from_input(user_prompts)
            answers = [self.extractor.extract_answer(solution, None) for solution in solutions]
            for idx, answer in enumerate(answers):
                answer_dict[idx].append(answer)
                solution_dict[idx].append((answer, solutions[idx]))
        self.logger.info(f"Generating final answers")
        # Pre-create the output columns (object dtype) so per-cell .at writes
        # of lists succeed.
        dataframe[self.output_key_answer] = dataframe.get(self.output_key_answer, None)
        dataframe[self.output_key_solutions] = dataframe.get(self.output_key_solutions, None)
        dataframe[self.output_key_correct_solution_example] = dataframe.get(self.output_key_correct_solution_example, None)
        for key, value in answer_dict.items():
            count = Counter(value)
            final_answer = count.most_common(1)[0][0]
            dataframe.at[int(key), self.output_key_answer] = value
            # Bug fix: removed a dead store that first wrote final_answer into
            # the solutions column only to overwrite it with the matching
            # solution texts immediately below.
            correct_contents = [content for ans, content in solution_dict[key] if ans == final_answer]
            dataframe.at[int(key), self.output_key_solutions] = correct_contents
            correct_solution_example = correct_contents[0] if correct_contents else None
            dataframe.at[int(key), self.output_key_correct_solution_example] = correct_solution_example
            dataframe.at[int(key), self.output_key_answer_value] = final_answer
        # Drop rows that ended up with no answer or no example solution.
        dataframe = dataframe[dataframe[self.output_key_answer_value].notna()]
        dataframe = dataframe[dataframe[self.output_key_correct_solution_example].notna()]
        self.logger.info(f"Data number {input_data_number} -> {dataframe.shape[0]}")
        output_file = storage.write(dataframe)
        self.logger.info(f"PsedoAnswerGenerator's results saved to {output_file}")
        return [output_key_answer, output_key_answer_value, output_key_solutions, output_key_correct_solution_example]
\ No newline at end of file
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow import get_logger
from dataflow.utils.storage import DataFlowStorage
from dataflow.core import OperatorABC
from dataflow.core import LLMServingABC
from dataflow.prompts.reasoning.math import MathQuestionParallelFusionGeneratorPrompt,MathQuestionSequentialFusionGeneratorPrompt,MathQuestionConditionFusionGeneratorPrompt
from dataflow.core.prompt import prompt_restrict, DIYPromptABC
from typing import Union
import pandas as pd
import random
@prompt_restrict(
    MathQuestionParallelFusionGeneratorPrompt,
    MathQuestionSequentialFusionGeneratorPrompt,
    MathQuestionConditionFusionGeneratorPrompt,
)
@OPERATOR_REGISTRY.register()
class ReasoningQuestionFusionGenerator(OperatorABC):
    """Fuse pairs of existing questions into new composite questions via an LLM."""

    def __init__(self,
                 num_prompts: int = 1,
                 llm_serving: LLMServingABC = None,
                 prompt_template: Union[MathQuestionParallelFusionGeneratorPrompt, MathQuestionSequentialFusionGeneratorPrompt, MathQuestionConditionFusionGeneratorPrompt, DIYPromptABC] = None
                 ):
        """
        Initialize the QuestionGenerator with the provided configuration.
        """
        self.logger = get_logger()
        if prompt_template is None:
            prompt_template = MathQuestionParallelFusionGeneratorPrompt()
        self.prompts = prompt_template
        self.num_prompts = num_prompts
        self.llm_serving = llm_serving

    @staticmethod
    def get_desc(lang: str = "zh"):
        if lang == "zh":
            return (
                "该算子用于基于现有问题生成新问题。\n"
                "输入参数:\n"
                "- num_prompts:生成问题的数量,整数,范围1-5(含),默认1\n"
                "- llm_serving:LLM服务实例,用于生成问题\n"
                "- prompt_template:提示模板对象,用于构建生成提示词\n"
                "输出参数:\n"
                "- 原始输入列(由input_key指定):新增生成的问题\n"
                "- Synth_or_Input:标识问题来源,'input'表示原始问题,'synth'表示生成的新问题"
            )
        elif lang == "en":
            return (
                "Generates new questions based on existing ones. \n"
                "Input Parameters:\n"
                "- num_prompts: Number of questions to generate per input, integer between 1-5 (inclusive), default 1\n"
                "- llm_serving: LLM serving instance for question generation\n"
                "- prompt_template: Prompt template object for constructing generation prompts\n"
                "Output Parameters:\n"
                "- Original input column (specified by input_key): Contains newly generated questions\n"
                "- Synth_or_Input: Indicates question source, 'input' for original questions, 'synth' for generated questions"
            )
        else:
            # Bug fix: a second, unreachable `elif lang == "en"` branch was
            # removed, and a fallback (matching the sibling operators) added
            # so non-zh/en callers no longer get None.
            return "QuestionFusionGenerator fuses pairs of questions into new composite questions."

    def _validate_dataframe(self, dataframe: pd.DataFrame):
        """Raise when either of the two required problem columns is missing.

        Bug fix: previously checked the never-assigned self.input_key and was
        never invoked from run().
        """
        required_keys = [self.input_problem_1, self.input_problem_2]
        missing = [k for k in required_keys if k not in dataframe.columns]
        if missing:
            raise ValueError(f"Missing required column(s): {missing}")

    def _reformat_prompt(self, dataframe):
        """Build the system prompt and one fusion prompt per row pair."""
        problem_1 = dataframe[self.input_problem_1].tolist()
        problem_2 = dataframe[self.input_problem_2].tolist()
        system_prompt = self.prompts.build_system_prompt()
        # Bug fix: the original looped over enumerate(zip(...)), binding p1 to
        # the row index and p2 to the (problem_1, problem_2) tuple, so prompts
        # were built from an int and a tuple instead of the two problem texts.
        prompts = [self.prompts.build_prompt(p1, p2) for p1, p2 in zip(problem_1, problem_2)]
        return system_prompt, prompts

    def run(self, storage: DataFlowStorage, input_problem_1: str, input_problem_2: str, output_key: str):
        """
        Run the question generation process.
        """
        self.input_problem_1, self.input_problem_2 = input_problem_1, input_problem_2
        dataframe = storage.read("dataframe")
        self._validate_dataframe(dataframe)
        for i in range(self.num_prompts):
            sys_prompts, user_prompts = self._reformat_prompt(dataframe)
            responses = self.llm_serving.generate_from_input(user_prompts, sys_prompts)
            dataframe[f"{output_key}_question_{i}"] = responses
            self.logger.info(f"Generated questions for {output_key}_{i}")
        output_file = storage.write(dataframe)
        self.logger.info(f"Generated questions saved to {output_file}")
        # Return the generated column names, consistent with sibling operators.
        return [f"{output_key}_question_{i}" for i in range(self.num_prompts)]
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow import get_logger
from dataflow.utils.storage import DataFlowStorage
from dataflow.core import OperatorABC
from dataflow.core import LLMServingABC
from dataflow.prompts.reasoning.math import MathQuestionSynthesisPrompt
from dataflow.prompts.reasoning.general import GeneralQuestionSynthesisPrompt
from dataflow.prompts.reasoning.diy import DiyQuestionSynthesisPrompt
from dataflow.core.prompt import prompt_restrict, DIYPromptABC
import pandas as pd
import random
import re
from typing import Union
import re
@prompt_restrict(
    MathQuestionSynthesisPrompt,
    GeneralQuestionSynthesisPrompt,
    DiyQuestionSynthesisPrompt
)
@OPERATOR_REGISTRY.register()
class ReasoningQuestionGenerator(OperatorABC):
    """Generate new reasoning questions from existing ones via an LLM service."""

    def __init__(self,
                num_prompts: int = 1,
                llm_serving: LLMServingABC = None,
                prompt_template: Union[MathQuestionSynthesisPrompt, GeneralQuestionSynthesisPrompt, DiyQuestionSynthesisPrompt, DIYPromptABC] = None
                ):
        """
        Initialize the QuestionGenerator with the provided configuration.

        Args:
            num_prompts: Number of question variants generated per input (1-5).
            llm_serving: LLM serving backend used for generation.
            prompt_template: Prompt builder; defaults to MathQuestionSynthesisPrompt.
        """
        self.logger = get_logger()
        if prompt_template is None:
            prompt_template = MathQuestionSynthesisPrompt()
        self.prompts = prompt_template
        self.num_prompts = num_prompts
        self.llm_serving = llm_serving
        if self.num_prompts not in range(1, 6):
            # Surface the misconfiguration: the old debug-level message was
            # invisible with default logging settings.
            self.logger.warning("num_prompts must be an integer between 1 and 5 (inclusive)")

    @staticmethod
    def get_desc(lang: str = "zh"):
        """Return the operator description in the requested language."""
        if lang == "zh":
            return (
                "该算子用于基于现有问题生成新问题。\n"
                "输入参数:\n"
                "- num_prompts:生成问题的数量,整数,范围1-5(含),默认1\n"
                "- llm_serving:LLM服务实例,用于生成问题\n"
                "- prompt_template:提示模板对象,用于构建生成提示词\n"
                "输出参数:\n"
                "- 原始输入列(由input_key指定):新增生成的问题\n"
                "- Synth_or_Input:标识问题来源,'input'表示原始问题,'synth'表示生成的新问题"
            )
        elif lang == "en":
            return (
                "Generates new questions based on existing ones. \n"
                "Input Parameters:\n"
                "- num_prompts: Number of questions to generate per input, integer between 1-5 (inclusive), default 1\n"
                "- llm_serving: LLM serving instance for question generation\n"
                "- prompt_template: Prompt template object for constructing generation prompts\n"
                "Output Parameters:\n"
                "- Original input column (specified by input_key): Contains newly generated questions\n"
                "- Synth_or_Input: Indicates question source, 'input' for original questions, 'synth' for generated questions"
            )
        # Bug fix: a second, identical `elif lang == "en"` branch was
        # unreachable dead code; it is removed and a fallback (consistent with
        # the other operators in this package) is returned for unknown langs.
        else:
            return "ReasoningQuestionGenerator generates new questions based on existing ones."

    def _validate_dataframe(self, dataframe: pd.DataFrame):
        """Ensure the input column exists and output columns are not clobbered."""
        required_keys = [self.input_key]
        forbidden_keys = ["Synth_or_Input"]
        missing = [k for k in required_keys if k not in dataframe.columns]
        conflict = [k for k in forbidden_keys if k in dataframe.columns]
        if missing:
            raise ValueError(f"Missing required column(s): {missing}")
        if conflict:
            raise ValueError(f"The following column(s) already exist and would be overwritten: {conflict}")

    def _reformat_prompt(self, dataframe):
        """
        Reformat the prompts in the dataframe to generate questions based on num_prompts.
        """
        # Each entry names a combination of transformation strategies the
        # prompt template knows how to apply.
        diversity_mode = [
            "1, 2, 3",
            "1, 2, 4",
            "1, 2, 5",
            "1, 4, 5",
            "1, 2, 3, 4, 5"
        ]
        formatted_prompts = []
        for question in dataframe[self.input_key]:
            if self.num_prompts == 0:
                formatted_prompts.append("")  # Skip generating for this question
            else:
                if not isinstance(self.prompts, DiyQuestionSynthesisPrompt):
                    # Randomly choose the required number of transformations from diversity_mode
                    selected_items = random.sample(diversity_mode, self.num_prompts)
                    for selected_item in selected_items:
                        used_prompt = self.prompts.build_prompt(selected_item, question)
                        formatted_prompts.append(used_prompt.strip())
                else:  # DIY prompt template
                    try:
                        used_prompt = self.prompts.build_prompt(question=question)
                        formatted_prompts.append(used_prompt.strip())
                    # Bug fix: a bare `except:` hid every failure (including
                    # KeyboardInterrupt); keep the best-effort skip but catch
                    # only real errors. Note the skipped row shortens the
                    # prompt list relative to the dataframe.
                    except Exception:
                        self.logger.debug(f"Please check if the symbol {{question}} in prompt is missing.")
        return formatted_prompts

    def _parse_response(self, response: str):
        """Strip DeepSeek-style <think>...</think><answer>...</answer> wrappers."""
        # useful for reasoning llms. If response is in format of Deepseek thinking: <think>...</think><answer>...</answer>, keep only the answer part.
        pattern = r"<think>(.*?)</think><answer>(.*?)</answer>"
        matches = re.findall(pattern, response)
        if matches:
            return matches[0][1]
        else:
            return response

    def run(
        self,
        storage: DataFlowStorage,
        input_key: str,
        output_synth_or_input_flag: str = "Synth_or_Input"
    ):
        """
        Run the question generation process: generate variants, tag their
        origin ('input' vs 'synth'), drop empty rows and persist the result.
        """
        self.input_key = input_key
        dataframe = storage.read("dataframe")
        self._validate_dataframe(dataframe)
        formatted_prompts = self._reformat_prompt(dataframe)
        responses = self.llm_serving.generate_from_input(formatted_prompts)
        responses = [self._parse_response(response) for response in responses]
        new_rows = pd.DataFrame({
            input_key: responses,
        })
        new_rows[output_synth_or_input_flag] = "synth"
        dataframe[output_synth_or_input_flag] = "input"
        dataframe = pd.concat([dataframe, new_rows], ignore_index=True)
        # Discard rows where generation produced nothing usable.
        dataframe = dataframe[dataframe[input_key].notna()]
        dataframe = dataframe[dataframe[input_key] != ""]
        output_file = storage.write(dataframe)
        self.logger.info(f"Generated questions saved to {output_file}")
        return [input_key, output_synth_or_input_flag]
# Lazy-loading package initializer for dataflow/operators/text2sql.
# Under static analysis the operators are imported directly; at runtime the
# module object is swapped for a LazyLoader so operators import on first use.
from typing import TYPE_CHECKING
if TYPE_CHECKING:
    # filter
    from filter.sql_consistency_filter import SQLConsistencyFilter
    from filter.sql_execution_filter import SQLExecutionFilter
    # generate
    from generate.sql_generator import SQLGenerator
    from generate.sql_by_column_generator import SQLByColumnGenerator
    from generate.sql_variation_generator import SQLVariationGenerator
    from generate.text2sql_cot_generator import Text2SQLCoTGenerator
    from generate.text2sql_prompt_generator import Text2SQLPromptGenerator
    from generate.text2sql_question_generator import Text2SQLQuestionGenerator
    # eval
    from eval.sql_component_classifier import SQLComponentClassifier
    from eval.sql_execution_classifier import SQLExecutionClassifier
else:
    import sys
    from dataflow.utils.registry import LazyLoader, generate_import_structure_from_type_checking
    # Derive the import map from the TYPE_CHECKING block above so both views
    # of the package stay in sync.
    cur_path = "dataflow/operators/text2sql/"
    _import_structure = generate_import_structure_from_type_checking(__file__, cur_path)
    # Replace this module with a LazyLoader that resolves names on demand.
    sys.modules[__name__] = LazyLoader(__name__, "dataflow/operators/text2sql/", _import_structure)
from tqdm import tqdm
from nltk import word_tokenize
import pandas as pd
import sqlite3
import re
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow import get_logger
from dataflow.core import OperatorABC
from dataflow.utils.storage import DataFlowStorage
class Schema:
    """Normalized database schema plus the identifier map used by the SQL parser."""

    def __init__(self, schema):
        # `schema` maps table name -> list of column names; both are
        # normalized (stripped, lower-cased) before building the id map.
        self._schema = self._normalize_schema(schema)
        self._idMap = self._map(self._schema)

    @property
    def schema(self):
        """The normalized table -> columns mapping."""
        return self._schema

    @property
    def idMap(self):
        """Mapping from '*', 'table' and 'table.column' to '__...__' ids."""
        return self._idMap

    def _normalize_schema(self, schema):
        # Strip whitespace and lower-case every table and column name.
        return {
            table.strip().lower(): [column.strip().lower() for column in columns]
            for table, columns in schema.items()
        }

    def _map(self, schema):
        mapping = {'*': "__all__"}
        for table, columns in schema.items():
            for column in columns:
                qualified = f"{table}.{column}"
                mapping[qualified] = f"__{qualified}__"
        mapping.update({table: f"__{table}__" for table in schema})
        return mapping
class EvalHardness:
    """
    Spider-style SQL hardness evaluator.

    Tokenizes a SQL query, parses it into Spider's nested dict representation
    (select / from / where / groupBy / orderBy / having / limit plus the set
    operations) and grades it as "easy" / "medium" / "hard" / "extra" based on
    component counts.
    """

    CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except')
    JOIN_KEYWORDS = ('join', 'on', 'as')
    WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists')
    UNIT_OPS = ('none', '-', '+', "*", '/')
    AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg')
    TABLE_TYPE = {
        'sql': "sql",
        'table_unit': "table_unit",
    }
    COND_OPS = ('and', 'or')
    SQL_OPS = ('intersect', 'union', 'except')
    ORDER_OPS = ('desc', 'asc')
    FUNC_OPS = ('cast', 'substring', 'date', 'round', 'coalesce')

    def __init__(self, schema, query):
        """
        Args:
            schema: a Schema instance describing the database.
            query: the raw SQL string to evaluate.
        """
        self.schema = schema
        self.query = query

    @property
    def tokenize(self):
        """Lower-cased token list of the query, with quoted literals kept intact."""
        string = str(self.query)
        vals = {}
        def replace_literal(match):
            # Swap each quoted literal for a placeholder so word_tokenize
            # cannot split it; backtick-quoted names get a __col_N__ key.
            val = match.group(0)
            quote = val[0]
            content = val[1:-1]
            if quote == '`':
                key = f"__col_{len(vals)}__"
            else:  # ' or "
                key = f"__str_{len(vals)}__"
            vals[key] = content
            return key
        string = re.sub(r"(['\"`])(?:\\.|[^\\])*?\1", replace_literal, string)
        toks = [word.lower() for word in word_tokenize(string)]
        # Restore the literal contents in place of the placeholders.
        for i in range(len(toks)):
            if toks[i] in vals:
                toks[i] = vals[toks[i]]
        # Re-join comparison operators split by the tokenizer: "!=", ">=", "<=".
        eq_idxs = [idx for idx, tok in enumerate(toks) if tok == "="]
        eq_idxs.reverse()
        prefix = ('!', '>', '<')
        for eq_idx in eq_idxs:
            pre_tok = toks[eq_idx - 1]
            if pre_tok in prefix:
                toks = toks[:eq_idx - 1] + [pre_tok + "="] + toks[eq_idx + 1:]
        return toks

    def scan_alias(self, toks):
        """Map alias -> table name for every "<table> as <alias>" occurrence."""
        as_idxs = [idx for idx, tok in enumerate(toks) if tok == 'as']
        alias = {}
        for idx in as_idxs:
            alias[toks[idx + 1]] = toks[idx - 1]
        return alias

    def get_tables_with_alias(self, toks):
        """Alias map extended so every real table name also maps to itself."""
        tables = self.scan_alias(toks)
        for key in self.schema.schema:
            assert key not in tables, "Alias {} has the same name in table".format(key)
            tables[key] = key
        return tables

    def parse_col(self, toks, start_idx, tables_with_alias, schema, default_tables=None):
        """Resolve a column token to its schema id; returns (next_idx, col_id)."""
        tok = toks[start_idx].strip().lower()
        if tok == "*":
            return start_idx + 1, schema.idMap[tok]
        if '.' in tok:  # qualified reference: alias.column
            alias, col = tok.split('.')
            key = tables_with_alias[alias] + "." + col
            return start_idx + 1, schema.idMap[key]
        assert default_tables is not None and len(default_tables) > 0, "Default tables should not be None or empty"
        # Unqualified column: search the tables currently in scope.
        for alias in default_tables:
            table = tables_with_alias[alias]
            if tok in schema.schema[table]:
                key = table + "." + tok
                return start_idx + 1, schema.idMap[key]
        assert False, "Error col: {}".format(tok)

    def _parse_func_call(self, toks, start_idx, tables_with_alias, schema, default_tables=None):
        """
        Shared helper for parse_col_unit/parse_val_unit (the two methods
        previously duplicated this code): parse "cast(col as type)" or
        "func(col, ...)" starting at the function name.
        Returns (next_idx, (agg_id, func_tuple, isDistinct)).
        """
        idx = start_idx
        func_name = toks[idx]
        idx += 1
        assert toks[idx] == '('
        idx += 1
        if func_name == 'cast':
            idx, col_id = self.parse_col(toks, idx, tables_with_alias, schema, default_tables)
            assert toks[idx] == 'as'
            idx += 1
            data_type = toks[idx]
            idx += 1
            assert toks[idx] == ')'
            idx += 1
            func_call = ('func', func_name, col_id, data_type)
        else:
            idx, col_id = self.parse_col(toks, idx, tables_with_alias, schema, default_tables)
            # Skip any remaining arguments up to the closing parenthesis.
            while toks[idx] != ')':
                idx += 1
            idx += 1
            func_call = ('func', func_name, col_id, None)
        return idx, (self.AGG_OPS.index("none"), func_call, False)

    def parse_col_unit(self, toks, start_idx, tables_with_alias, schema, default_tables=None):
        """Parse a column unit; returns (next_idx, (agg_id, col_id, isDistinct))."""
        idx = start_idx
        if toks[idx] in self.FUNC_OPS:
            return self._parse_func_call(toks, idx, tables_with_alias, schema, default_tables)
        len_ = len(toks)
        isBlock = False
        isDistinct = False
        if toks[idx] == '(':
            isBlock = True
            idx += 1
        if toks[idx] in self.AGG_OPS:
            # Aggregate form: agg(col) or agg(distinct col).
            agg_id = self.AGG_OPS.index(toks[idx])
            idx += 1
            assert idx < len_ and toks[idx] == '('
            idx += 1
            if toks[idx] == "distinct":
                idx += 1
                isDistinct = True
            idx, col_id = self.parse_col(toks, idx, tables_with_alias, schema, default_tables)
            assert idx < len_ and toks[idx] == ')'
            idx += 1
            return idx, (agg_id, col_id, isDistinct)
        if toks[idx] == "distinct":
            idx += 1
            isDistinct = True
        agg_id = self.AGG_OPS.index("none")
        idx, col_id = self.parse_col(toks, idx, tables_with_alias, schema, default_tables)
        if isBlock:
            assert toks[idx] == ')'
            idx += 1
        return idx, (agg_id, col_id, isDistinct)

    def parse_val_unit(self, toks, start_idx, tables_with_alias, schema, default_tables=None):
        """Parse a value unit; returns (next_idx, (unit_op, col_unit1, col_unit2))."""
        idx = start_idx
        if toks[idx] in self.FUNC_OPS:
            return self._parse_func_call(toks, idx, tables_with_alias, schema, default_tables)
        len_ = len(toks)
        isBlock = False
        if toks[idx] == '(':
            isBlock = True
            idx += 1
        col_unit1 = None
        col_unit2 = None
        unit_op = self.UNIT_OPS.index('none')
        idx, col_unit1 = self.parse_col_unit(toks, idx, tables_with_alias, schema, default_tables)
        if idx < len_ and toks[idx] in self.UNIT_OPS:
            # Arithmetic between two column units, e.g. col1 - col2.
            unit_op = self.UNIT_OPS.index(toks[idx])
            idx += 1
            idx, col_unit2 = self.parse_col_unit(toks, idx, tables_with_alias, schema, default_tables)
        if isBlock:
            assert toks[idx] == ')'
            idx += 1
        return idx, (unit_op, col_unit1, col_unit2)

    def parse_table_unit(self, toks, start_idx, tables_with_alias, schema):
        """Parse a table reference; returns (next_idx, table_id, table_name)."""
        idx = start_idx
        len_ = len(toks)
        key = tables_with_alias[toks[idx]]
        if idx + 1 < len_ and toks[idx + 1] == "as":
            idx += 3  # skip "<table> as <alias>"
        else:
            idx += 1
        return idx, schema.idMap[key], key

    def parse_value(self, toks, start_idx, tables_with_alias, schema, default_tables=None):
        """Parse a literal, number, subquery or column reference used as a value."""
        idx = start_idx
        len_ = len(toks)
        isBlock = False
        if toks[idx] == '(':
            isBlock = True
            idx += 1
        if toks[idx] == 'select':
            # Bug fix: parse_sql takes only a start index (it re-derives toks,
            # aliases and schema itself); the old 4-argument call
            # parse_sql(toks, idx, tables_with_alias, schema) raised TypeError
            # on every nested SELECT used as a value.
            idx, val = self.parse_sql(idx)
        elif isinstance(toks[idx], str) and toks[idx] not in schema.idMap:
            val = toks[idx]
            idx += 1
        else:
            try:
                val = float(toks[idx])
                idx += 1
            except ValueError:  # narrowed from a bare except
                end_idx = idx
                while end_idx < len_ and toks[end_idx] not in (',', ')', 'and', *self.CLAUSE_KEYWORDS, *self.JOIN_KEYWORDS):
                    end_idx += 1
                idx, val = self.parse_col_unit(toks[start_idx:end_idx], 0, tables_with_alias, schema, default_tables)
                idx = end_idx
        if isBlock:
            assert toks[idx] == ')'
            idx += 1
        return idx, val

    def parse_condition(self, toks, start_idx, tables_with_alias, schema, default_tables=None):
        """Parse a condition list: cond_unit tuples interleaved with 'and'/'or'."""
        idx = start_idx
        len_ = len(toks)
        conds = []
        while idx < len_:
            idx, val_unit = self.parse_val_unit(toks, idx, tables_with_alias, schema, default_tables)
            not_op = False
            if toks[idx] == 'not':
                not_op = True
                idx += 1
            assert idx < len_ and toks[idx] in self.WHERE_OPS, "Error condition: idx: {}, tok: {}".format(idx, toks[idx])
            op_id = self.WHERE_OPS.index(toks[idx])
            idx += 1
            val1 = val2 = None
            if op_id == self.WHERE_OPS.index('between'):
                idx, val1 = self.parse_value(toks, idx, tables_with_alias, schema, default_tables)
                assert toks[idx] == 'and'
                idx += 1
                idx, val2 = self.parse_value(toks, idx, tables_with_alias, schema, default_tables)
            else:
                idx, val1 = self.parse_value(toks, idx, tables_with_alias, schema, default_tables)
                val2 = None
            conds.append((not_op, op_id, val_unit, val1, val2))
            # Stop at the next clause keyword, closing paren or join keyword.
            if idx < len_ and (toks[idx] in self.CLAUSE_KEYWORDS or toks[idx] in (")", ";") or toks[idx] in self.JOIN_KEYWORDS):
                break
            if idx < len_ and toks[idx] in self.COND_OPS:
                conds.append(toks[idx])
                idx += 1
        return idx, conds

    def parse_select(self, toks, start_idx, tables_with_alias, schema, default_tables=None):
        """Parse the select clause; returns (next_idx, (isDistinct, val_units))."""
        idx = start_idx
        len_ = len(toks)
        assert toks[idx] == 'select', "'select' not found"
        idx += 1
        isDistinct = False
        if idx < len_ and toks[idx] == 'distinct':
            idx += 1
            isDistinct = True
        val_units = []
        while idx < len_ and toks[idx] not in self.CLAUSE_KEYWORDS:
            agg_id = self.AGG_OPS.index("none")
            if toks[idx] in self.AGG_OPS:
                agg_id = self.AGG_OPS.index(toks[idx])
                idx += 1
            idx, val_unit = self.parse_val_unit(toks, idx, tables_with_alias, schema, default_tables)
            val_units.append((agg_id, val_unit))
            if idx < len_ and toks[idx] == ',':
                idx += 1
        return idx, (isDistinct, val_units)

    def parse_from(self, toks, start_idx, tables_with_alias, schema):
        """Parse the from clause; returns (next_idx, table_units, conds, default_tables)."""
        assert 'from' in toks[start_idx:], "'from' not found"
        len_ = len(toks)
        idx = toks.index('from', start_idx) + 1
        default_tables = []
        table_units = []
        conds = []
        while idx < len_:
            isBlock = False
            if toks[idx] == '(':
                isBlock = True
                idx += 1
            if toks[idx] == 'select':
                # Bug fix: same arity issue as in parse_value — parse_sql only
                # accepts a start index, so subqueries in FROM crashed before.
                idx, sql = self.parse_sql(idx)
                table_units.append((self.TABLE_TYPE['sql'], sql))
            else:
                if idx < len_ and toks[idx] == 'join':
                    idx += 1
                idx, table_unit, table_name = self.parse_table_unit(toks, idx, tables_with_alias, schema)
                table_units.append((self.TABLE_TYPE['table_unit'], table_unit))
                default_tables.append(table_name)
            if idx < len_ and toks[idx] == "on":
                idx += 1
                idx, this_conds = self.parse_condition(toks, idx, tables_with_alias, schema, default_tables)
                if len(conds) > 0:
                    conds.append('and')
                conds.extend(this_conds)
            if isBlock:
                assert toks[idx] == ')'
                idx += 1
            if idx < len_ and (toks[idx] in self.CLAUSE_KEYWORDS or toks[idx] in (")", ";")):
                break
        return idx, table_units, conds, default_tables

    def parse_where(self, toks, start_idx, tables_with_alias, schema, default_tables):
        """Parse an optional where clause; returns (next_idx, conds)."""
        idx = start_idx
        len_ = len(toks)
        if idx >= len_ or toks[idx] != 'where':
            return idx, []
        idx += 1
        idx, conds = self.parse_condition(toks, idx, tables_with_alias, schema, default_tables)
        return idx, conds

    def parse_group_by(self, toks, start_idx, tables_with_alias, schema, default_tables):
        """Parse an optional "group by" clause; returns (next_idx, col_units)."""
        idx = start_idx
        len_ = len(toks)
        col_units = []
        if idx >= len_ or toks[idx] != 'group':
            return idx, col_units
        idx += 1
        assert toks[idx] == 'by'
        idx += 1
        while idx < len_ and not (toks[idx] in self.CLAUSE_KEYWORDS or toks[idx] in (")", ";")):
            idx, col_unit = self.parse_col_unit(toks, idx, tables_with_alias, schema, default_tables)
            col_units.append(col_unit)
            if idx < len_ and toks[idx] == ',':
                idx += 1
            else:
                break
        return idx, col_units

    def parse_order_by(self, toks, start_idx, tables_with_alias, schema, default_tables):
        """Parse an optional "order by" clause; returns (next_idx, (order_type, val_units))."""
        idx = start_idx
        len_ = len(toks)
        val_units = []
        order_type = 'asc'  # SQL default when no asc/desc token is present
        if idx >= len_ or toks[idx] != 'order':
            return idx, val_units
        idx += 1
        assert toks[idx] == 'by'
        idx += 1
        while idx < len_ and not (toks[idx] in self.CLAUSE_KEYWORDS or toks[idx] in (")", ";")):
            idx, val_unit = self.parse_val_unit(toks, idx, tables_with_alias, schema, default_tables)
            val_units.append(val_unit)
            if idx < len_ and toks[idx] in self.ORDER_OPS:
                order_type = toks[idx]
                idx += 1
            if idx < len_ and toks[idx] == ',':
                idx += 1
            else:
                break
        return idx, (order_type, val_units)

    def parse_having(self, toks, start_idx, tables_with_alias, schema, default_tables):
        """Parse an optional having clause; returns (next_idx, conds)."""
        idx = start_idx
        len_ = len(toks)
        if idx >= len_ or toks[idx] != 'having':
            return idx, []
        idx += 1
        idx, conds = self.parse_condition(toks, idx, tables_with_alias, schema, default_tables)
        return idx, conds

    def parse_limit(self, toks, start_idx):
        """Parse "limit N"; returns (next_idx, N), (next_idx, 1) for a
        non-numeric operand, or (start_idx, None) when absent."""
        idx = start_idx
        len_ = len(toks)
        if idx < len_ and toks[idx] == 'limit':
            idx += 2
            # Bug fix: tokens are strings, so the old `type(...) != int` test
            # was always true and every LIMIT collapsed to 1. Parse the token
            # and only fall back to 1 when it is not a number.
            try:
                return idx, int(toks[idx - 1])
            except (ValueError, TypeError):
                return idx, 1
        return idx, None

    def skip_semicolon(self, toks, start_idx):
        """Advance past any run of ';' tokens."""
        idx = start_idx
        while idx < len(toks) and toks[idx] == ";":
            idx += 1
        return idx

    def parse_sql(self, start_idx):
        """Parse a full (sub)query starting at start_idx; returns (next_idx, sql_dict)."""
        toks = self.tokenize
        tables_with_alias = self.get_tables_with_alias(toks)
        schema = self.schema
        isBlock = False  # indicate whether this is a block of sql/sub-sql
        len_ = len(toks)
        idx = start_idx
        sql = {}
        if toks[idx] == '(':
            isBlock = True
            idx += 1
        # The from clause is parsed first so the tables in scope
        # (default_tables) are known when resolving unqualified columns.
        from_end_idx, table_units, conds, default_tables = self.parse_from(toks, start_idx, tables_with_alias, schema)
        sql['from'] = {'table_units': table_units, 'conds': conds}
        # select clause
        _, select_col_units = self.parse_select(toks, idx, tables_with_alias, schema, default_tables)
        idx = from_end_idx
        sql['select'] = select_col_units
        # where clause
        idx, where_conds = self.parse_where(toks, idx, tables_with_alias, schema, default_tables)
        sql['where'] = where_conds
        # group by clause
        idx, group_col_units = self.parse_group_by(toks, idx, tables_with_alias, schema, default_tables)
        sql['groupBy'] = group_col_units
        # having clause
        idx, having_conds = self.parse_having(toks, idx, tables_with_alias, schema, default_tables)
        sql['having'] = having_conds
        # order by clause
        idx, order_col_units = self.parse_order_by(toks, idx, tables_with_alias, schema, default_tables)
        sql['orderBy'] = order_col_units
        # limit clause
        idx, limit_val = self.parse_limit(toks, idx)
        sql['limit'] = limit_val
        idx = self.skip_semicolon(toks, idx)
        if isBlock:
            assert toks[idx] == ')'
            idx += 1
            idx = self.skip_semicolon(toks, idx)
        # intersect/union/except clause
        for op in self.SQL_OPS:  # initialize IUE
            sql[op] = None
        if idx < len_ and toks[idx] in self.SQL_OPS:
            sql_op = toks[idx]
            idx += 1
            idx, IUE_sql = self.parse_sql(idx)
            sql[sql_op] = IUE_sql
        return idx, sql

    def has_agg(self, unit):
        """True when the (agg_id, ...) unit carries a real aggregation."""
        return unit[0] != self.AGG_OPS.index('none')

    def count_agg(self, units):
        """Number of units that carry an aggregation."""
        return len([unit for unit in units if self.has_agg(unit)])

    def get_nestedSQL(self, sql):
        """Collect all nested sub-queries (in conditions and set operations)."""
        nested = []
        for cond_unit in sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]:
            if type(cond_unit[3]) is dict:
                nested.append(cond_unit[3])
            if type(cond_unit[4]) is dict:
                nested.append(cond_unit[4])
        if sql['intersect'] is not None:
            nested.append(sql['intersect'])
        if sql['except'] is not None:
            nested.append(sql['except'])
        if sql['union'] is not None:
            nested.append(sql['union'])
        return nested

    def count_component1(self, sql):
        """Count basic components: clauses, joins, OR connectives and LIKEs."""
        count = 0
        if len(sql['where']) > 0:
            count += 1
        if len(sql['groupBy']) > 0:
            count += 1
        if len(sql['orderBy']) > 0:
            count += 1
        if sql['limit'] is not None:
            count += 1
        if len(sql['from']['table_units']) > 0:
            count += len(sql['from']['table_units']) - 1  # joins beyond the first table
        ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2]
        count += len([token for token in ao if token == 'or'])
        cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]
        count += len([cond_unit for cond_unit in cond_units if cond_unit[1] == self.WHERE_OPS.index('like')])
        return count

    def count_component2(self, sql):
        """Count nested sub-queries."""
        nested = self.get_nestedSQL(sql)
        return len(nested)

    def count_others(self, sql):
        """Count 'other' complexity: multiple aggregations, columns, conditions."""
        count = 0
        # number of aggregation
        agg_count = self.count_agg(sql['select'][1])
        agg_count += self.count_agg(sql['where'][::2])
        agg_count += self.count_agg(sql['groupBy'])
        if len(sql['orderBy']) > 0:
            agg_count += self.count_agg([unit[1] for unit in sql['orderBy'][1] if unit[1]] +
                                        [unit[2] for unit in sql['orderBy'][1] if unit[2]])
        # NOTE(review): this passes the raw having list (cond units interleaved
        # with 'and'/'or' strings) to count_agg, which the Spider reference
        # sliced with [::2] — confirm whether counting connectives is intended.
        agg_count += self.count_agg(sql['having'])
        if agg_count > 1:
            count += 1
        # number of select columns
        if len(sql['select'][1]) > 1:
            count += 1
        # number of where conditions
        if len(sql['where']) > 1:
            count += 1
        # number of group by clauses
        if len(sql['groupBy']) > 1:
            count += 1
        return count

    def eval_hardness(self, sql):
        """Grade a parsed sql dict as easy/medium/hard/extra (Spider rules)."""
        count_comp1_ = self.count_component1(sql)
        count_comp2_ = self.count_component2(sql)
        count_others_ = self.count_others(sql)
        if count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ == 0:
            return "easy"
        elif (count_others_ <= 2 and count_comp1_ <= 1 and count_comp2_ == 0) or \
                (count_comp1_ <= 2 and count_others_ < 2 and count_comp2_ == 0):
            return "medium"
        elif (count_others_ > 2 and count_comp1_ <= 2 and count_comp2_ == 0) or \
                (2 < count_comp1_ <= 3 and count_others_ <= 2 and count_comp2_ == 0) or \
                (count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ <= 1):
            return "hard"
        else:
            return "extra"

    def run(self):
        """Parse self.query and return its hardness label."""
        _, sql = self.parse_sql(0)
        hardness = self.eval_hardness(sql)
        return hardness
class EvalHardnessLite:
    """Lightweight, regex-based SQL difficulty scorer (no full parsing)."""

    def __init__(self, sql: str, difficulty_config: dict):
        # Work on the lower-cased query. difficulty_config carries
        # 'thresholds' (ascending ints) and 'labels' (len(thresholds) + 1).
        self.sql = sql.lower()
        self.difficulty_config = difficulty_config

    def match(self, pattern):
        """True when the regex pattern occurs anywhere in the query."""
        return bool(re.search(pattern, self.sql))

    def count_keyword(self, keyword):
        """Number of (plain substring) occurrences of keyword in the query."""
        return self.sql.count(keyword)

    def classify_difficulty(self, score: int) -> str:
        """Map a numeric score to the first label whose threshold covers it."""
        thresholds = self.difficulty_config['thresholds']
        labels = self.difficulty_config['labels']
        for position, limit in enumerate(thresholds):
            if score <= limit:
                return labels[position]
        return labels[-1]

    def run(self):
        """Score the query's surface complexity and return its difficulty label."""
        sql = self.sql
        score = 2 if self.match(r'\( *select') else 0
        score += self.count_keyword(' join ')  # one point per join
        if self.count_keyword(',') > 0 and 'from' in sql:
            score += 1
        if self.count_keyword(' and ') + self.count_keyword(' or ') >= 2:
            score += 1
        # NOTE(review): these are substring tests, so e.g. a column named
        # "inventory" would match 'in' — confirm whether that is acceptable.
        if any(kw in sql for kw in ('in', 'exists', 'like')):
            score += 1
        for clause in ('group by', 'having'):
            if clause in sql:
                score += 1
        if any(func in sql for func in ('cast', 'round', 'substring', 'date', 'coalesce')):
            score += 1
        for clause in ('order by', 'limit'):
            if clause in sql:
                score += 1
        if any(op in sql for op in ('union', 'intersect', 'except')):
            score += 2
        projection = re.findall(r'select\s+(distinct\s+)?(.+?)\s+from', sql, re.DOTALL)
        if projection and projection[0][1].count(',') >= 1:
            score += 1
        return self.classify_difficulty(score)
@OPERATOR_REGISTRY.register()
class SQLComponentClassifier(OperatorABC):
    """Classify each SQL row's difficulty from its surface components."""

    def __init__(self,
                 difficulty_thresholds: list,
                 difficulty_labels: list
                 ):
        """
        Args:
            difficulty_thresholds: ascending score cut-offs.
            difficulty_labels: one label per threshold plus a final fallback,
                so len(labels) must equal len(thresholds) + 1.
        """
        self.difficulty_config = {
            'thresholds': difficulty_thresholds,
            'labels': difficulty_labels
        }
        self.logger = get_logger()
        if len(self.difficulty_config['thresholds']) != len(self.difficulty_config['labels']) - 1:
            raise ValueError("Thresholds and labels configuration mismatch")

    def check_column(self, dataframe):
        """Raise if the configured SQL column is missing from the dataframe."""
        required_columns = [self.input_sql_key]
        missing_columns = [col for col in required_columns if col not in dataframe.columns]
        if missing_columns:
            raise ValueError(f"Missing required columns: {missing_columns}")

    @staticmethod
    def get_desc(lang):
        """Return the operator description in the requested language."""
        if lang == "zh":
            return (
                "根据SQL的组件数量和复杂度,评估SQL的难度。\n\n"
                "输入参数:\n"
                "- input_sql_key: 输入SQL列名\n\n"
                "输出参数:\n"
                "- output_difficulty_key: 输出难度列名"
            )
        elif lang == "en":
            return (
                "This operator evaluates the difficulty of SQL components based on the number and complexity of components.\n\n"
                "Input parameters:\n"
                "- input_sql_key: The name of the input SQL column\n\n"
                "Output parameters:\n"
                "- output_difficulty_key: The name of the output difficulty column"
            )
        else:
            return "SQL component difficulty evaluator for Text2SQL tasks."

    def report_statistics(self, dataframe: pd.DataFrame):
        """Log how many rows fell into each configured difficulty label."""
        counts = dataframe[self.output_difficulty_key].value_counts()
        self.logger.info("SQL Difficulty Statistics")
        # Bug fix: iterate the configured labels rather than a hard-coded
        # ['easy', 'medium', 'hard', 'extra'] list, which silently dropped
        # counts when custom labels were supplied in __init__.
        difficulty_counts = {d: counts.get(d, 0) for d in self.difficulty_config['labels']}
        self.logger.info(" | ".join([f"{d.title()}: {v}" for d, v in difficulty_counts.items()]))

    def run(self, storage: DataFlowStorage,
            input_sql_key: str = "SQL",
            output_difficulty_key: str = "sql_component_difficulty"):
        """Read the dataframe, grade every SQL row and persist the result."""
        self.input_sql_key = input_sql_key
        self.output_difficulty_key = output_difficulty_key
        dataframe = storage.read("dataframe")
        self.check_column(dataframe)
        for idx, row in tqdm(dataframe.iterrows(), total=len(dataframe), desc="Processing"):
            sql = row.get(self.input_sql_key)
            sql_hardness = EvalHardnessLite(sql, self.difficulty_config)
            hardness = sql_hardness.run()
            dataframe.at[idx, self.output_difficulty_key] = hardness
        self.report_statistics(dataframe)
        output_file = storage.write(dataframe)
        # Bug fix: the old message ("Extracted answers saved ...") was
        # copy-pasted from another operator and misdescribed the output.
        self.logger.info(f"Difficulty classification results saved to {output_file}")
        return [self.output_difficulty_key]
import pandas as pd
import os
import re
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow import get_logger
from dataflow.core import OperatorABC
from dataflow.core import LLMServingABC
from dataflow.utils.storage import DataFlowStorage
from dataflow.utils.text2sql.database_manager import DatabaseManager
@OPERATOR_REGISTRY.register()
class SQLExecutionClassifier(OperatorABC):
def __init__(self,
llm_serving: LLMServingABC,
database_manager: DatabaseManager,
num_generations: int = 10,
difficulty_thresholds: list = [2, 5, 9],
difficulty_labels: list = ['extra', 'hard', 'medium', 'easy']
):
self.llm_serving = llm_serving
self.database_manager = database_manager
self.difficulty_config = {
'thresholds': difficulty_thresholds,
'labels': difficulty_labels
}
self.num_generations = num_generations
self.timeout = 5.0 # Default timeout for SQL execution
self.logger = get_logger()
if self.num_generations <= self.difficulty_config["thresholds"][-1]:
nearest_multiple = ((self.difficulty_config["thresholds"][-1] // 5) + 1) * 5
self.logger.warning(f"num_generations is less than the last threshold, will be set to {nearest_multiple}")
self.num_generations = nearest_multiple
if len(self.difficulty_config['thresholds']) != len(self.difficulty_config['labels']) - 1:
raise ValueError("Thresholds and labels configuration mismatch")
@staticmethod
def get_desc(lang):
if lang == "zh":
return (
"让模型根据自然语言问题、数据库Schema和提示词,多次生成SQL,通过生成SQL的准确率,评估该问题对于模型的难度。\n\n"
"输入参数:\n"
"- input_db_id_key: 输入数据库ID列名\n"
"- input_sql_key: 输入SQL列名\n"
"- input_prompt_key: 输入prompt列名\n\n"
"输出参数:\n"
"- output_difficulty_key: 输出难度列名"
)
elif lang == "en":
return (
"This operator evaluates the difficulty of SQL generation for a question based on the accuracy of generated SQLs.\n\n"
"Input parameters:\n"
"- input_db_id_key: The name of the input database ID column\n"
"- input_sql_key: The name of the input SQL column\n"
"- input_prompt_key: The name of the input prompt column\n\n"
"Output parameters:\n"
"- output_difficulty_key: The name of the output difficulty column"
)
else:
return "SQL execution difficulty evaluator for Text2SQL tasks."
def check_column(self, dataframe):
required_columns = [self.input_db_id_key, self.input_sql_key, self.input_prompt_key]
missing_columns = [col for col in required_columns if col not in dataframe.columns]
if missing_columns:
raise ValueError(f"Missing required columns: {missing_columns}")
@staticmethod
def parse_response(response):
pattern = r"```sql\s*(.*?)\s*```"
sql_blocks = re.findall(pattern, response, re.DOTALL)
if sql_blocks:
last_sql = sql_blocks[-1].strip()
return last_sql
else:
return ""
    @staticmethod
    def execute_model_batch(predicted_sqls_list, ground_truth_list, database_manager, db_ids, idxs, meta_time_out, logger):
        """
        Compare every predicted SQL against its ground truth in one database
        batch call; returns one result dict per question:
        {"idx", "cnt_true" (correct count, -1 on batch failure), "results"}.

        NOTE(review): meta_time_out is accepted but never used here — confirm
        whether the timeout should be forwarded to batch_compare_queries.
        """
        # Flatten (question, candidate) pairs into one comparison list and
        # remember where each comparison came from.
        comparisons = []
        sql_mapping = {}
        comparison_idx = 0
        for i, (predicted_sqls, ground_truth, db_id, idx) in enumerate(zip(predicted_sqls_list, ground_truth_list, db_ids, idxs)):
            for j, predicted_sql in enumerate(predicted_sqls):
                comparison = (db_id, predicted_sql, ground_truth)
                comparisons.append(comparison)
                sql_mapping[comparison_idx] = {
                    'original_idx': i,
                    'sql_idx': j,
                    'idx': idx,
                    'sql': predicted_sql
                }
                comparison_idx += 1
        try:
            batch_results = database_manager.batch_compare_queries(comparisons)
        except Exception as e:
            # Whole-batch failure: mark every candidate as failed and flag the
            # question with cnt_true = -1 so downstream can distinguish it.
            logger.error(f"Batch comparison failed: {e}")
            results = []
            for i, (predicted_sqls, _, _, idx) in enumerate(zip(predicted_sqls_list, ground_truth_list, db_ids, idxs)):
                result_data = []
                for predicted_sql in predicted_sqls:
                    result_data.append({'res': 0, 'sql': predicted_sql, 'error': 'batch_execution_failed'})
                results.append({"idx": idx, "cnt_true": -1, "results": result_data})
            return results
        # Scatter the flat batch results back to their questions.
        results = {}
        for i, (predicted_sqls, _, _, idx) in enumerate(zip(predicted_sqls_list, ground_truth_list, db_ids, idxs)):
            results[i] = {
                "idx": idx,
                "cnt_true": 0,
                "results": [None] * len(predicted_sqls)
            }
        for batch_idx, comparison_result in enumerate(batch_results):
            original_idx = sql_mapping[batch_idx]['original_idx']
            sql_idx = sql_mapping[batch_idx]['sql_idx']
            if comparison_result['result1_success'] and comparison_result['result2_success']:
                # Both queries executed: correct iff result sets are equal.
                res = 1 if comparison_result['equal'] else 0
                results[original_idx]['results'][sql_idx] = {
                    'res': res,
                    'sql': sql_mapping[batch_idx]['sql'],
                    'error': None
                }
                if res == 1:
                    results[original_idx]['cnt_true'] += 1
            else:
                # At least one side failed to execute; record which.
                error_msg = ""
                if not comparison_result['result1_success']:
                    error_msg += f"Predicted SQL failed; "
                if not comparison_result['result2_success']:
                    error_msg += f"Ground truth SQL failed"
                results[original_idx]['results'][sql_idx] = {
                    'res': 0,
                    'sql': sql_mapping[batch_idx]['sql'],
                    'error': error_msg
                }
        return [results[i] for i in sorted(results.keys())]
def run_sqls_parallel(self, datas, database_manager, num_cpus, meta_time_out):
# pbar = tqdm(total=len(datas), desc="Executing SQLs")
exec_result = []
predicted_sqls_list = []
ground_truth_list = []
db_ids = []
idxs = []
for i, data_pair in enumerate(datas):
predicted_sqls = data_pair[self.output_predicted_sqls_key]
ground_truth = data_pair[self.input_sql_key]
db_id = data_pair[self.input_db_id_key].replace('\n', '')
db_id = re.sub(r'[^A-Za-z0-9_]', '', db_id)
predicted_sqls_list.append(predicted_sqls)
ground_truth_list.append(ground_truth)
db_ids.append(db_id)
idxs.append(i)
# batch_size = max(1, len(datas) // num_cpus) if num_cpus > 1 else len(datas)
batch_size = len(datas)
def process_batch(batch_data):
batch_predicted_sqls, batch_ground_truth, batch_db_ids, batch_idxs = batch_data
# Note: self.timeout is not defined, so I am removing it from the call
return SQLExecutionClassifier.execute_model_batch(
batch_predicted_sqls, batch_ground_truth, database_manager,
batch_db_ids, batch_idxs, meta_time_out, self.logger
)
batches = []
for i in range(0, len(datas), batch_size):
end_idx = min(i + batch_size, len(datas))
batch = (
predicted_sqls_list[i:end_idx],
ground_truth_list[i:end_idx],
db_ids[i:end_idx],
idxs[i:end_idx]
)
batches.append(batch)
with ThreadPoolExecutor(max_workers=num_cpus) as executor:
futures = [executor.submit(process_batch, batch) for batch in batches]
for future in as_completed(futures):
try:
batch_results = future.result()
exec_result.extend(batch_results)
# pbar.update(len(batch_results))
except Exception as e:
self.logger.warning(f"Error in batch SQL execution: {e}")
# Add default results for failed batches to ensure all indices are covered
batch_idx = futures.index(future)
batch = batches[batch_idx]
batch_idxs = batch[3] # Get the indices for this batch
for idx in batch_idxs:
exec_result.append({
"idx": idx,
"cnt_true": 0,
"results": []
})
# pbar.update(batch_size)
# pbar.close()
return exec_result
def sort_results(self, list_of_dicts):
# 增加对空列表的判断
if not list_of_dicts:
return []
return sorted(list_of_dicts, key=lambda x: x['idx'])
def report_statistics(self, dataframe: pd.DataFrame):
# 增加对列是否存在的判断
if self.output_difficulty_key not in dataframe.columns:
self.logger.warning(f"Column '{self.output_difficulty_key}' not found in dataframe for reporting stats. Skipping.")
return
counts = dataframe[self.output_difficulty_key].value_counts()
self.logger.info("SQL Difficulty Statistics")
stats = [f"{difficulty.title()}: {counts.get(difficulty, 0)}" for difficulty in ['easy', 'medium', 'hard', 'extra']]
self.logger.info(", ".join(stats))
def classify_difficulty(self, score):
if score == -1:
return "gold error"
thresholds = self.difficulty_config['thresholds']
labels = self.difficulty_config['labels']
for i, threshold in enumerate(thresholds):
if score <= threshold:
return labels[i]
return labels[-1]
    def run(self, storage: DataFlowStorage,
        input_db_id_key: str = "db_id",
        input_sql_key: str = "SQL",
        input_prompt_key: str = "rl_prompt",
        output_difficulty_key: str = "sql_execution_difficulty"
        ):
        """Classify each question's difficulty by execution agreement.

        For every row, ``self.num_generations`` candidate SQLs are sampled
        from the LLM, executed against the database, and the count of
        agreeing executions (``cnt_true``) is mapped to a label via
        ``classify_difficulty``. The label is written to
        ``output_difficulty_key``; temporary columns are dropped before the
        dataframe is persisted.

        Returns:
            The list of output column names added to the dataframe.
        """
        self.input_sql_key = input_sql_key
        self.input_prompt_key = input_prompt_key
        self.input_db_id_key = input_db_id_key
        self.output_difficulty_key = output_difficulty_key
        # Underscore-prefixed working columns; removed before writing output.
        self.output_predicted_sqls_key = "_temp_predicted_sqls"
        self.output_cnt_true_key = "_temp_cnt_true"
        dataframe = storage.read("dataframe")
        self.check_column(dataframe)
        input_prompts = dataframe[self.input_prompt_key].tolist()
        self.logger.info(f"Processing {len(input_prompts)} questions, generating {self.num_generations} SQLs each...")
        # Repeat every prompt num_generations times so one flat LLM call
        # covers all samples; responses are sliced back per question below.
        prompts = [q for q in input_prompts for _ in range(self.num_generations)]
        responses = self.llm_serving.generate_from_input(prompts, system_prompt="You are a helpful assistant.")
        datas = dataframe.to_dict(orient='records')
        for i, data in enumerate(datas):
            # Responses for question i occupy a contiguous slice of size
            # num_generations in the flat response list.
            start_idx = i * self.num_generations
            end_idx = start_idx + self.num_generations
            question_responses = responses[start_idx:end_idx]
            parsed_sqls = []
            for response in question_responses:
                if response:
                    # Calling the static method correctly now
                    parsed_sql = SQLExecutionClassifier.parse_response(response)
                    parsed_sqls.append(parsed_sql)
                else:
                    # Empty/None responses become empty SQL strings so the
                    # per-question list keeps its length.
                    parsed_sqls.append("")
            data[self.output_predicted_sqls_key] = parsed_sqls
        exec_result = self.run_sqls_parallel(datas, self.database_manager,
                                            num_cpus=os.cpu_count(),
                                            meta_time_out=5.0)
        exec_result = self.sort_results(exec_result)
        for execres in exec_result:
            if execres is not None:
                idx = execres["idx"]
                cnt_true = execres["cnt_true"]
                datas[idx][self.output_difficulty_key] = self.classify_difficulty(cnt_true)
                datas[idx][self.output_cnt_true_key] = cnt_true
        # Drop the temporary working columns before persisting.
        for data in datas:
            data.pop(self.output_predicted_sqls_key, None)
            data.pop(self.output_cnt_true_key, None)
        dataframe = pd.DataFrame(datas)
        self.report_statistics(dataframe)
        # 增加对列是否存在的判断 (guard: column may be missing if nothing was classified)
        if self.output_difficulty_key in dataframe.columns:
            difficulty_counts = dataframe[self.output_difficulty_key].value_counts()
            self.logger.info("\nDifficulty Distribution:")
            for difficulty in ['easy', 'medium', 'hard', 'extra', 'gold error']:
                count = difficulty_counts.get(difficulty, 0)
                dataframe_len = len(dataframe) if dataframe is not None else 0
                if dataframe_len > 0:
                    percentage = count / dataframe_len * 100
                    self.logger.info(
                        f"  {difficulty}: {count} ({percentage:.1f}%)")
                else:
                    self.logger.info(f"  {difficulty}: {count} (0.0%)")
        else:
            self.logger.warning("Skipping difficulty distribution report because the column was not generated.")
        output_file = storage.write(dataframe)
        self.logger.info(f"Difficulty classification results saved to {output_file}")
        return [self.output_difficulty_key]
from typing import Dict, Union
from tqdm import tqdm
import re
from dataflow.prompts.text2sql import SQLConsistencyFilterPrompt
from dataflow.core.prompt import prompt_restrict, DIYPromptABC
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow import get_logger
from dataflow.core import OperatorABC
from dataflow.core import LLMServingABC
from dataflow.utils.storage import DataFlowStorage
from dataflow.utils.text2sql.database_manager import DatabaseManager
@prompt_restrict(SQLConsistencyFilterPrompt)
@OPERATOR_REGISTRY.register()
class SQLConsistencyFilter(OperatorABC):
    """Keep only rows where an LLM judges that the SQL answers the question.

    For each row, a judgment prompt is built from the question, the SQL and
    the database details; rows whose response contains a fenced block with
    'yes' survive the filter.
    """

    def __init__(self,
                llm_serving: LLMServingABC,
                database_manager: DatabaseManager,
                prompt_template: Union[SQLConsistencyFilterPrompt, DIYPromptABC] = None
                ):
        """Store collaborators; default to SQLConsistencyFilterPrompt when no template is given."""
        self.llm_serving = llm_serving
        if prompt_template is None:
            prompt_template = SQLConsistencyFilterPrompt()
        self.prompt_template = prompt_template
        self.database_manager = database_manager
        self.logger = get_logger()

    @staticmethod
    def get_desc(lang):
        """Return the operator description for 'zh'/'en'; short fallback otherwise."""
        if lang == "zh":
            return (
                "对条目进行过滤,检测SQL和自然语言问题是否对应,即判断SQL是否能解决该问题。\n\n"
                "输入参数:\n"
                "- input_sql_key: 输入SQL列名\n"
                "- input_db_id_key: 输入数据库ID列名\n"
                "- input_question_key: 输入问题列名\n\n"
            )
        elif lang == "en":
            return (
                "This operator filters items based on whether the SQL can solve the question.\n\n"
                "Input parameters:\n"
                "- input_sql_key: The name of the input SQL column\n"
                "- input_db_id_key: The name of the input database ID column\n"
                "- input_question_key: The name of the input question column\n\n"
            )
        else:
            return "SQL consistency filter for Text2SQL tasks."

    def _parse_consistency_response(self, response):
        """Return True iff any ``` fenced block in the (lower-cased) response contains 'yes'."""
        response_lower = response.lower() if response else ""
        pattern = r"```\s*(.*?)\s*```"
        ans_blocks = re.findall(pattern, response_lower, re.DOTALL)
        for ans_block in ans_blocks:
            if 'yes' in ans_block:
                return True
        return False

    def check_column(self, dataframe):
        """Raise ValueError when any required input column is missing."""
        required_columns = [self.input_sql_key, self.input_db_id_key, self.input_question_key]
        missing_columns = [col for col in required_columns if col not in dataframe.columns]
        if missing_columns:
            raise ValueError(f"Missing required columns: {missing_columns}")

    def run(self, storage: DataFlowStorage,
            input_sql_key: str = "SQL",
            input_db_id_key: str = "db_id",
            input_question_key: str = "question"
            ):
        """Filter the stored dataframe down to consistent (question, SQL) rows.

        Returns an empty list: no new columns are produced.
        """
        self.input_sql_key = input_sql_key
        self.input_db_id_key = input_db_id_key
        self.input_question_key = input_question_key
        dataframe = storage.read("dataframe")
        self.check_column(dataframe)
        total_len = len(dataframe)
        prompts = []
        final_valid_indices = []
        for _, row in tqdm(dataframe.iterrows(), total=len(dataframe), desc="Processing consistency check"):
            sql = row[self.input_sql_key]
            question = row[self.input_question_key]
            db_id = row[self.input_db_id_key]
            db_details = self.database_manager.get_db_details(db_id)
            prompt = self.prompt_template.build_prompt(question, sql, db_details)
            prompts.append(prompt)
        responses = self.llm_serving.generate_from_input(prompts, "")
        for idx, response in enumerate(responses):
            conclusion = self._parse_consistency_response(response)
            if conclusion:
                final_valid_indices.append(idx)
        consistency_passed = len(final_valid_indices)
        self.logger.info(f"Consistency check results: {consistency_passed} passed, total {total_len}")
        if final_valid_indices:
            # BUG FIX: these indices are positional offsets (from enumerate),
            # not index labels, so select with .iloc. The previous .loc call
            # mis-selected rows (or raised KeyError) whenever the dataframe
            # did not carry a default RangeIndex.
            filtered_dataframe = dataframe.iloc[final_valid_indices].copy()
        else:
            self.logger.warning("No data passed all filters. Returning empty dataset.")
            filtered_dataframe = dataframe.iloc[0:0].copy()
        output_file = storage.write(filtered_dataframe)
        self.logger.info(f"Filtered dataset saved to {output_file}")
        return []
\ No newline at end of file
import re
import os
import pandas as pd
from tqdm import tqdm
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow import get_logger
from dataflow.core import OperatorABC
from dataflow.utils.storage import DataFlowStorage
from dataflow.utils.text2sql.database_manager import DatabaseManager
@OPERATOR_REGISTRY.register()
class SQLExecutionFilter(OperatorABC):
    """Keep only rows whose SQL is a SELECT/WITH query that executes successfully.

    Phase 1 drops non-SELECT statements by inspecting the text; phase 2
    actually executes the survivors against their databases and keeps the
    ones that succeed.
    """

    def __init__(self, database_manager: DatabaseManager):
        """Store the database manager used to execute queries."""
        self.database_manager = database_manager
        self.logger = get_logger()

    @staticmethod
    def get_desc(lang):
        """Return the operator description for 'zh'/'en'; short fallback otherwise."""
        if lang == "zh":
            return (
                "对条目进行过滤,在数据库中执行SQL,筛选掉不可执行的条目。\n\n"
                "输入参数:\n"
                "- input_sql_key: 输入SQL列名\n"
                "- input_db_id_key: 输入数据库ID列名\n\n"
            )
        elif lang == "en":
            return (
                "This operator filters items based on whether the SQL can be executed in the database.\n\n"
                "Input parameters:\n"
                "- input_sql_key: The name of the input SQL column\n"
                "- input_db_id_key: The name of the input database ID column\n\n"
            )
        else:
            return "SQL execution filter for Text2SQL tasks."

    def filter_select_sql(self, sql):
        """Return True iff `sql` (comments stripped) starts with SELECT or WITH."""
        sql_wo_comments = re.sub(r'/\*.*?\*/', '', sql, flags=re.DOTALL)
        sql_wo_comments = re.sub(r'--.*', '', sql_wo_comments)
        sql_wo_comments = sql_wo_comments.strip()
        if sql_wo_comments.lower().startswith("select") or \
            sql_wo_comments.lower().startswith("with"):
            return True
        return False

    def check_column(self, dataframe):
        """Raise ValueError when any required input column is missing."""
        required_columns = [self.input_sql_key, self.input_db_id_key]
        missing_columns = [col for col in required_columns if col not in dataframe.columns]
        if missing_columns:
            raise ValueError(f"Missing required columns: {missing_columns}")

    def run(self, storage: DataFlowStorage,
            input_sql_key: str = "sql",
            input_db_id_key: str = "db_id"
            ):
        """Filter the stored dataframe to executable SELECT-type SQLs.

        Returns an empty list: no new columns are produced.
        """
        self.input_sql_key = input_sql_key
        self.input_db_id_key = input_db_id_key
        dataframe = storage.read("dataframe")
        self.check_column(dataframe)
        db_id_need_to_check = dataframe[input_db_id_key].unique()
        for db_id in db_id_need_to_check:
            if not self.database_manager.database_exists(db_id):
                self.logger.warning(f"Database {db_id} not found in registry, please check the database folder")
                continue
        self.logger.info(f"Start to filter {len(dataframe)} SQLs")
        self.logger.info("Filtering SQLs using select component")
        # Phase 1: keep only SELECT/WITH statements. Track *positions* (not
        # index labels) so later selection can use .iloc safely.
        phase1_positions = [
            pos for pos, sql in enumerate(dataframe[input_sql_key])
            if self.filter_select_sql(sql)
        ]
        self.logger.info(f"Phase 1 completed: {len(phase1_positions)}/{len(dataframe)} SQLs passed initial filter")
        # BUG FIX: the phase-1 result was previously computed but ignored —
        # every SQL (including non-SELECT ones) was still executed, and the
        # final positional indices were fed to label-based .loc. Now only
        # phase-1 survivors are executed, and selection is positional.
        phase1_df = dataframe.iloc[phase1_positions]
        sql_pairs = list(zip(phase1_df[input_db_id_key], phase1_df[input_sql_key]))
        execution_results = self.database_manager.batch_execute_queries(sql_pairs)
        # Phase 2: map successful executions back to original row positions.
        final_positions = [
            phase1_positions[i]
            for i, exec_result in enumerate(execution_results)
            if exec_result.success
        ]
        self.logger.info(f"Filter completed, remaining {len(final_positions)} SQLs out of {len(dataframe)} original SQLs")
        result_df = dataframe.iloc[final_positions]
        output_file = storage.write(result_df)
        return []
\ No newline at end of file
import random
import pandas as pd
import re
from dataflow.prompts.text2sql import SelectSQLGeneratorPrompt
from dataflow.prompts.text2sql import SelectVecSQLGeneratorPrompt
from dataflow.core.prompt import prompt_restrict, DIYPromptABC
from tqdm import tqdm
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow import get_logger
from dataflow.core import OperatorABC
from dataflow.core import LLMServingABC
from dataflow.utils.storage import DataFlowStorage
from dataflow.utils.text2sql.database_manager import DatabaseManager
from typing import Union
@prompt_restrict(SelectSQLGeneratorPrompt, SelectVecSQLGeneratorPrompt)
@OPERATOR_REGISTRY.register()
class SQLByColumnGenerator(OperatorABC):
    """Synthesize SQL per database, scaling the count by its special columns.

    For every registered database, ``special_column_count * generate_num``
    prompts are built from its CREATE/INSERT statements and sent to the LLM;
    parsed SQLs are written out together with their database id.
    """

    def __init__(self,
                 llm_serving: LLMServingABC,
                 database_manager: DatabaseManager,
                 generate_num: int = 5,
                 prompt_template: Union[SelectSQLGeneratorPrompt, SelectVecSQLGeneratorPrompt, DIYPromptABC] = None
                 ):
        """Store collaborators; default to SelectSQLGeneratorPrompt when no template is given."""
        self.llm_serving = llm_serving
        self.logger = get_logger()
        self.database_manager = database_manager
        self.generate_num = generate_num
        self.prompt_template = SelectSQLGeneratorPrompt() if prompt_template is None else prompt_template
        # Fixed seed so prompt sampling is reproducible across runs.
        random.seed(42)

    @staticmethod
    def get_desc(lang):
        """Return the operator description for 'zh'/'en'; short fallback otherwise."""
        if lang == "zh":
            return (
                "基于数据库信息,合成SQL,覆盖不同的难度、数据库Schema、函数和风格。\n\n"
                "输出参数:\n"
                "- output_sql_key: 输出SQL列名\n"
                "- output_db_id_key: 数据库ID列名\n\n"
            )
        if lang == "en":
            return (
                "This operator synthesizes SQL based on database information, covering different complexities, schemas, functions, and styles.\n\n"
                "Output parameters:\n"
                "- output_sql_key: The name of the output SQL column\n"
                "- output_db_id_key: The name of the database ID column\n\n"
            )
        return "SQL generator for Text2SQL tasks."

    def parse_response(self, response):
        """Extract the last ```sql fenced block from an LLM response ('' if none)."""
        if not response:
            return ""
        blocks = re.findall(r"```sql\s*(.*?)\s*```", response, re.DOTALL)
        if not blocks:
            self.logger.warning("No SQL code block found in the response")
            return ""
        return blocks[-1].strip()

    def run(self, storage: DataFlowStorage,
            output_sql_key: str = "sql",
            output_db_id_key: str = "db_id"
            ):
        """Generate SQLs for every database and write a (db_id, sql) dataframe."""
        self.output_sql_key = output_sql_key
        self.output_db_id_key = output_db_id_key
        # Content unused; the read is kept for the pipeline contract —
        # TODO(review): confirm it is required.
        raw_dataframe = storage.read("dataframe")
        db_names = self.database_manager.list_databases()
        self.logger.info(f"Generating {self.generate_num} SQLs for each database")
        pending = []
        for db_name in tqdm(db_names, desc="Processing Databases"):
            special_col_count = self.database_manager.get_number_of_special_column(db_name)
            sum_generate_num = special_col_count * self.generate_num
            self.logger.info(f"Database '{db_name}' has {special_col_count} special columns. "
                            f"Generating {sum_generate_num} SQLs.")
            create_statements, insert_statements = self.database_manager.get_create_statements_and_insert_statements(db_name)
            # build_prompt is called once per generated SQL (it may sample
            # internally), so keep the per-iteration call.
            for _ in range(sum_generate_num):
                pending.append((
                    db_name,
                    self.prompt_template.build_prompt(
                        insert_statements=insert_statements,
                        create_statements=create_statements,
                        db_engine=self.database_manager.db_type
                    ),
                ))
        if not pending:
            self.logger.warning("No prompts generated, please check the database path and file")
            return [self.output_sql_key, self.output_db_id_key]
        db_ids = [db_name for db_name, _ in pending]
        prompt_list = [prompt for _, prompt in pending]
        try:
            responses = self.llm_serving.generate_from_input(prompt_list, "")
        except Exception as e:
            self.logger.error(f"Failed to generate SQLs: {e}")
            responses = [""] * len(prompt_list)
        records = [
            {self.output_db_id_key: db_id, self.output_sql_key: self.parse_response(response)}
            for db_id, response in zip(db_ids, responses)
        ]
        output_file = storage.write(pd.DataFrame(records))
        return [self.output_sql_key, self.output_db_id_key]
import random
import pandas as pd
import re
from dataflow.prompts.text2sql import SelectSQLGeneratorPrompt
from dataflow.prompts.text2sql import SelectVecSQLGeneratorPrompt
from dataflow.core.prompt import prompt_restrict, DIYPromptABC
from tqdm import tqdm
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow import get_logger
from dataflow.core import OperatorABC
from dataflow.core import LLMServingABC
from dataflow.utils.storage import DataFlowStorage
from dataflow.utils.text2sql.database_manager import DatabaseManager
from typing import Union
@prompt_restrict(SelectSQLGeneratorPrompt, SelectVecSQLGeneratorPrompt)
@OPERATOR_REGISTRY.register()
class SQLGenerator(OperatorABC):
    """Synthesize a fixed number of SQL queries per registered database.

    ``generate_num`` prompts are built per database from its CREATE/INSERT
    statements and sent to the LLM; parsed SQLs are written out together
    with their database id.
    """

    def __init__(self,
                 llm_serving: LLMServingABC,
                 database_manager: DatabaseManager,
                 generate_num: int = 300,
                 prompt_template: Union[SelectSQLGeneratorPrompt, SelectVecSQLGeneratorPrompt, DIYPromptABC] = None
                 ):
        """Store collaborators; default to SelectSQLGeneratorPrompt when no template is given."""
        self.llm_serving = llm_serving
        self.logger = get_logger()
        self.database_manager = database_manager
        self.generate_num = generate_num
        if prompt_template is None:
            prompt_template = SelectSQLGeneratorPrompt()
        self.prompt_template = prompt_template
        # Fixed seed so prompt sampling is reproducible across runs.
        random.seed(42)

    @staticmethod
    def get_desc(lang):
        """Return the operator description for 'zh'/'en'; short fallback otherwise."""
        if lang == "zh":
            return (
                "基于数据库信息,合成SQL,覆盖不同的难度、数据库Schema、函数和风格。\n\n"
                "输出参数:\n"
                "- output_sql_key: 输出SQL列名\n"
                "- output_db_id_key: 数据库ID列名\n\n"
            )
        if lang == "en":
            return (
                "This operator synthesizes SQL based on database information, covering different complexities, schemas, functions, and styles.\n\n"
                "Output parameters:\n"
                "- output_sql_key: The name of the output SQL column\n"
                "- output_db_id_key: The name of the database ID column\n\n"
            )
        return "SQL generator for Text2SQL tasks."

    def parse_response(self, response):
        """Pull the last ```sql fenced code block out of `response` ('' if none)."""
        if not response:
            return ""
        found = re.findall(r"```sql\s*(.*?)\s*```", response, re.DOTALL)
        if found:
            return found[-1].strip()
        self.logger.warning("No SQL code block found in the response")
        return ""

    def run(self, storage: DataFlowStorage,
            output_sql_key: str = "sql",
            output_db_id_key: str = "db_id"
            ):
        """Generate SQLs for every database and write a (db_id, sql) dataframe."""
        self.output_sql_key = output_sql_key
        self.output_db_id_key = output_db_id_key
        # Content unused; the read is kept for the pipeline contract —
        # TODO(review): confirm it is required.
        raw_dataframe = storage.read("dataframe")
        db_names = self.database_manager.list_databases()
        self.logger.info(f"Generating {self.generate_num} SQLs for each database")
        db_ids = []
        prompt_list = []
        for db_name in tqdm(db_names, desc="Processing Databases"):
            create_statements, insert_statements = self.database_manager.get_create_statements_and_insert_statements(db_name)
            # build_prompt is called once per generated SQL (it may sample
            # internally), so keep the per-iteration call.
            for _ in range(self.generate_num):
                db_ids.append(db_name)
                prompt_list.append(self.prompt_template.build_prompt(
                    insert_statements=insert_statements,
                    create_statements=create_statements,
                    db_engine=self.database_manager.db_type
                ))
        if not prompt_list:
            self.logger.warning("No prompts generated, please check the database path and file")
            return [self.output_sql_key, self.output_db_id_key]
        try:
            responses = self.llm_serving.generate_from_input(prompt_list, "")
        except Exception as e:
            self.logger.error(f"Failed to generate SQLs: {e}")
            responses = [""] * len(prompt_list)
        rows = [
            {self.output_db_id_key: db_id, self.output_sql_key: self.parse_response(resp)}
            for db_id, resp in zip(db_ids, responses)
        ]
        output_file = storage.write(pd.DataFrame(rows))
        return [self.output_sql_key, self.output_db_id_key]
import random
import pandas as pd
import re
from dataflow.prompts.text2sql import SQLVariationGeneratorPrompt
from dataflow.core.prompt import prompt_restrict, DIYPromptABC
from tqdm import tqdm
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow.utils.storage import (DataFlowStorage, RESERVED_SYS_FIELD_LIST, RESERVED_USER_FIELD_LIST,
SYS_FIELD_PREFIX, USER_FIELD_PREFIX)
from dataflow.utils.text2sql.database_manager import DatabaseManager
from dataflow import get_logger
from dataflow.core import OperatorABC
from dataflow.core import LLMServingABC
from typing import Union
@prompt_restrict(SQLVariationGeneratorPrompt)
@OPERATOR_REGISTRY.register()
class SQLVariationGenerator(OperatorABC):
    """Append LLM-generated variants of each row's SQL to the dataframe.

    For every input row, ``num_variations`` variation prompts (data
    substitution, function/difficulty transformation) are built from the
    original SQL and the database schema; successfully parsed variants are
    appended as new rows, copying the reserved sys/user fields from the
    originating row.
    """

    def __init__(self, llm_serving: LLMServingABC,
                database_manager: DatabaseManager,
                num_variations: int = 10,
                prompt_template: Union[SQLVariationGeneratorPrompt, DIYPromptABC] = None
                ):
        """Store collaborators; default to SQLVariationGeneratorPrompt when no template is given."""
        self.llm_serving = llm_serving
        self.logger = get_logger()
        self.database_manager = database_manager
        if prompt_template is None:
            self.prompt_template = SQLVariationGeneratorPrompt()
        else:
            self.prompt_template = prompt_template
        self.num_variations = num_variations
        # Fixed seed so prompt sampling is reproducible across runs.
        random.seed(42)

    @staticmethod
    def get_desc(lang):
        """Return the operator description for 'zh'/'en'; short fallback otherwise."""
        if lang == "zh":
            return (
                "对于每个条目,基于已有的SQL,指导模型生成SQL的变种,即在原有SQL的基础上,进行数据替换、函数变换、难度变换等操作,生成更加丰富的SQL。\n\n"
                "输入参数:\n"
                "- input_sql_key: SQL列名\n"
                "- input_db_id_key: 数据库ID列名\n\n"
            )
        elif lang == "en":
            return (
                "This operator generates variations of SQL based on existing SQLs, including data replacement, function transformation, and difficulty transformation, to generate more diverse SQLs.\n\n"
                "Input parameters:\n"
                "- input_sql_key: The name of the SQL column\n"
                "- input_db_id_key: The name of the database ID column\n\n"
            )
        else:
            return "SQL variation generator for Text2SQL tasks."

    def parse_response(self, response):
        """Extract the last ```sql fenced block from an LLM response ('' if none)."""
        if not response:
            return ""
        pattern = r"```sql\s*(.*?)\s*```"
        sql_blocks = re.findall(pattern, response, re.DOTALL)
        if sql_blocks:
            last_sql = sql_blocks[-1].strip()
            return last_sql
        else:
            self.logger.warning("No SQL code block found in the response")
            return ""

    def check_column(self, dataframe):
        """Raise ValueError when any required input column is missing."""
        required_columns = [self.input_sql_key, self.input_db_id_key]
        missing_columns = [col for col in required_columns if col not in dataframe.columns]
        if missing_columns:
            raise ValueError(f"Missing required columns: {missing_columns}")

    def run(self, storage: DataFlowStorage,
            input_sql_key: str = "sql",
            input_db_id_key: str = "db_id"
            ):
        """Generate SQL variations for every row and append them to the dataframe.

        Returns an empty list: no new columns are produced.
        """
        self.input_sql_key = input_sql_key
        self.input_db_id_key = input_db_id_key
        dataframe = storage.read("dataframe")
        self.check_column(dataframe)
        original_count = len(dataframe)
        prompts_and_metadata = []
        # BUG FIX: track *positions* via enumerate. The previous code stored
        # iterrows() index labels but consumed them with positional .iloc,
        # which pairs a variation with the wrong source row whenever the
        # dataframe index is not a default RangeIndex.
        original_row_positions = []
        for row_pos, (_, row) in enumerate(tqdm(dataframe.iterrows(), total=len(dataframe), desc="Generating SQL Variations")):
            try:
                create_statements, insert_statements = self.database_manager.get_create_statements_and_insert_statements(row[self.input_db_id_key])
                original_sql = row[self.input_sql_key]
                for _ in range(self.num_variations):
                    prompt = self.prompt_template.build_prompt(
                        original_sql=original_sql,
                        create_statements=create_statements,
                        insert_statements=insert_statements,
                        db_engine=self.database_manager.db_type
                    )
                    prompts_and_metadata.append((
                        prompt,
                        row[self.input_db_id_key]
                    ))
                    original_row_positions.append(row_pos)
            except Exception as e:
                self.logger.error(f"Error processing database {row[self.input_db_id_key]}: {e}")
                continue
        if prompts_and_metadata:
            try:
                prompts = [prompt for prompt, db_id in prompts_and_metadata]
                responses = self.llm_serving.generate_from_input(prompts, system_prompt="")
                # BUG FIX: collect the new rows and concat ONCE after the
                # loop. The previous per-row pd.concat rebuilt the whole
                # dataframe on every accepted variation (quadratic time).
                new_rows = []
                for i, ((prompt, db_id), response) in enumerate(zip(prompts_and_metadata, responses)):
                    sql = self.parse_response(response)
                    if not sql:
                        continue
                    original_row = dataframe.iloc[original_row_positions[i]]
                    new_row = {col: None for col in dataframe.columns}
                    new_row[self.input_db_id_key] = db_id
                    new_row[self.input_sql_key] = sql
                    # Propagate reserved sys/user bookkeeping fields from the
                    # originating row so the variant stays traceable.
                    for sys_field in RESERVED_SYS_FIELD_LIST:
                        sys_col = f"{SYS_FIELD_PREFIX}{sys_field}"
                        if sys_col in dataframe.columns and sys_col in original_row:
                            new_row[sys_col] = original_row[sys_col]
                    for user_field in RESERVED_USER_FIELD_LIST:
                        user_col = f"{USER_FIELD_PREFIX}{user_field}"
                        if user_col in dataframe.columns and user_col in original_row:
                            new_row[user_col] = original_row[user_col]
                    new_rows.append(new_row)
                if new_rows:
                    dataframe = pd.concat([dataframe, pd.DataFrame(new_rows)], ignore_index=True)
            except Exception as e:
                self.logger.error(f"Error generating SQL variations: {e}")
        output_file = storage.write(dataframe)
        self.logger.info(f"Generated {len(dataframe)} records (original: {original_count}, variations: {len(dataframe) - original_count})")
        return []
\ No newline at end of file
from typing import Dict, Optional, Tuple, List, Union
import pandas as pd
import re
from dataflow.prompts.text2sql import Text2SQLCotGeneratorPrompt
from dataflow.core.prompt import DIYPromptABC, prompt_restrict
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow import get_logger
from dataflow.core import OperatorABC
from dataflow.core import LLMServingABC
from dataflow.utils.storage import DataFlowStorage
from dataflow.utils.text2sql.database_manager import DatabaseManager
@prompt_restrict(Text2SQLCotGeneratorPrompt)
@OPERATOR_REGISTRY.register()
class Text2SQLCoTGenerator(OperatorABC):
    """Generate chain-of-thought (CoT) reasoning for (question, SQL) pairs.

    For each row, an LLM produces a reasoning trace ending in a SQL query;
    the trace is accepted only if that SQL's execution result matches the
    gold SQL. Failed rows are retried up to ``max_retries`` rounds and
    discarded afterwards.
    """

    def __init__(self, llm_serving: LLMServingABC,
        database_manager: DatabaseManager,
        prompt_template: Union[Text2SQLCotGeneratorPrompt, DIYPromptABC] = None
        ):
        self.llm_serving = llm_serving
        self.database_manager = database_manager
        # Default to the stock CoT prompt template when none is supplied.
        if prompt_template is None:
            prompt_template = Text2SQLCotGeneratorPrompt()
        self.prompt_template = prompt_template
        self.logger = get_logger()
        # Number of generation rounds attempted per item.
        self.max_retries = 3
        # NOTE(review): enable_retry is set but not consulted anywhere in this
        # class — confirm whether it should gate the retry loop.
        self.enable_retry = True

    @staticmethod
    def get_desc(lang):
        """Return the operator description for 'zh'/'en'; short fallback otherwise."""
        if lang == "zh":
            return (
                "对于每个条目,生成从自然语言问题和数据库Schema到SQL的CoT长链路推理过程。\n\n"
                "输入参数:\n"
                "- input_sql_key: 输入SQL列名\n"
                "- input_question_key: 输入问题列名\n"
                "- input_db_id_key: 输入数据库ID列名\n\n"
                "输出参数:\n"
                "- output_cot_key: 输出CoT列名"
            )
        elif lang == "en":
            return (
                "This operator generates CoT for SQL with long chain reasoning from natural language questions and database schemas.\n\n"
                "Input parameters:\n"
                "- input_sql_key: The name of the input SQL column\n"
                "- input_question_key: The name of the input question column\n"
                "- input_db_id_key: The name of the input database ID column\n\n"
                "Output parameters:\n"
                "- output_cot_key: The name of the output CoT column"
            )
        else:
            return "CoT generator for Text2SQL tasks with long chain reasoning."

    def check_column(self, dataframe):
        """Raise ValueError when any required input column is missing."""
        required_columns = [self.input_sql_key, self.input_db_id_key, self.input_question_key]
        missing_columns = [col for col in required_columns if col not in dataframe.columns]
        if missing_columns:
            raise ValueError(f"Missing required columns: {missing_columns}")

    def extract_sql(self, response):
        """Return the last ```sql fenced block from `response`, or '' if none."""
        pattern = r"```sql\s*(.*?)\s*```"
        sql_blocks = re.findall(pattern, response, re.DOTALL)
        if sql_blocks:
            last_sql = sql_blocks[-1].strip()
            return last_sql
        else:
            return ""

    def _parse_response(self, response: str, gold_sql: str, db_id: str) -> Tuple[Optional[str], bool]:
        """Extract the SQL from a CoT response and compare it against the gold SQL.

        Returns (generated_sql, equal); (None, False) when no SQL block was
        found, and (sql, False) when the comparison fails or raises.
        """
        generated_sql = self.extract_sql(response)
        if not generated_sql:
            return None, False
        try:
            ans = self.database_manager.compare_queries(db_id, generated_sql, gold_sql)
            if ans:
                return generated_sql, True
            return generated_sql, False
        except Exception as e:
            self.logger.error(f"SQL execution failed: {db_id}, Error: {e}")
            return generated_sql, False

    def _process_items_with_retry(self, items: List[Dict], max_retries: int = 3) -> List[Dict]:
        """Generate and validate CoT for `items`, retrying failures per round.

        Each round regenerates CoT for all still-failed items, validates the
        extracted SQLs via a batch comparison (falling back to per-item
        comparison if the batch call raises), and carries failures into the
        next round. Items still failing after `max_retries` rounds are
        dropped. Returns the list of enriched item dicts (original fields
        plus ``self.output_cot_key``).
        """
        results = []
        failed_items = items.copy()
        for retry_round in range(max_retries):
            if not failed_items:
                break
            self.logger.info(f"Start {retry_round + 1} round processing, {len(failed_items)} items to process")
            # Build one CoT prompt per pending item from its schema + question + gold SQL.
            prompts = []
            for item in failed_items:
                db_id = item.get(self.input_db_id_key)
                question = item.get(self.input_question_key)
                sql = item.get(self.input_sql_key)
                evidence = item.get(self.input_evidence_key)
                create_statements, _ = self.database_manager.get_create_statements_and_insert_statements(db_id)
                schema_str = "\n\n".join(create_statements)
                cot_prompt = self.prompt_template.build_prompt(schema_str, question, sql, evidence)
                prompts.append(cot_prompt)
            cot_responses = self.llm_serving.generate_from_input(prompts, "")
            # Collect only responses that contained an extractable SQL block;
            # the rest are routed straight back into the failed set below.
            comparisons = []
            valid_items_with_responses = []
            for item, response in zip(failed_items, cot_responses):
                db_id = item.get(self.input_db_id_key)
                gold_sql = item.get(self.input_sql_key)
                generated_sql = self.extract_sql(response)
                if generated_sql:
                    comparisons.append((db_id, generated_sql, gold_sql))
                    valid_items_with_responses.append((item, response, generated_sql))
            if comparisons:
                try:
                    # Fast path: compare all generated SQLs in one batch call.
                    batch_results = self.database_manager.batch_compare_queries(comparisons)
                    current_round_failed = []
                    for (item, response, generated_sql), batch_result in zip(valid_items_with_responses, batch_results):
                        db_id = item.get(self.input_db_id_key)
                        if batch_result.get('equal', False):
                            results.append({
                                **item,
                                self.output_cot_key: response
                            })
                            self.logger.debug(f"Successfully processed {db_id} (Round {retry_round + 1})")
                        else:
                            current_round_failed.append(item)
                            if batch_result.get('differences'):
                                self.logger.debug(f"SQL comparison failed for {db_id}: {batch_result['differences']}")
                    # Items whose response had no SQL block never entered the
                    # batch; re-queue them for the next round. (Membership is
                    # checked by dict equality.)
                    for item, response in zip(failed_items, cot_responses):
                        if item not in [valid_item for valid_item, _, _ in valid_items_with_responses]:
                            current_round_failed.append(item)
                except Exception as e:
                    # Fallback path: batch comparison raised — validate each
                    # item individually via _parse_response.
                    self.logger.error(f"Batch SQL comparison failed: {e}")
                    current_round_failed = []
                    for item, response in zip(failed_items, cot_responses):
                        db_id = item.get(self.input_db_id_key)
                        gold_sql = item.get(self.input_sql_key)
                        parsed_response, success = self._parse_response(response, gold_sql, db_id)
                        if success and parsed_response:
                            results.append({
                                **item,
                                self.output_cot_key: response
                            })
                            self.logger.debug(f"Successfully processed {db_id} (Round {retry_round + 1})")
                        else:
                            current_round_failed.append(item)
            else:
                # No response produced any SQL block: everything is retried.
                current_round_failed = failed_items
            failed_items = current_round_failed
            self.logger.info(f"Text2SQL CoT Generation: Round {retry_round + 1} completed, Success: {len(results)}, Failed: {len(failed_items)}")
        if failed_items:
            self.logger.warning(f"Still {len(failed_items)} items failed, will be discarded")
            for item in failed_items:
                self.logger.debug(f"Discarded failed item: {item.get(self.input_db_id_key)}")
        return results

    def run(self, storage: DataFlowStorage,
        input_sql_key: str = "SQL",
        input_question_key: str = "question",
        input_db_id_key: str = "db_id",
        input_evidence_key: str = "evidence",
        output_cot_key: str = "cot_reasoning"
        ):
        """Generate validated CoT reasoning for every row and persist the result.

        Returns the list of output column names added (``[output_cot_key]``),
        or an empty list when no item passed validation.
        """
        self.input_question_key = input_question_key
        self.input_sql_key = input_sql_key
        self.input_db_id_key = input_db_id_key
        self.input_evidence_key = input_evidence_key
        self.output_cot_key = output_cot_key
        self.logger.info("Starting CoT generation...")
        raw_dataframe = storage.read("dataframe")
        self.check_column(raw_dataframe)
        items = raw_dataframe.to_dict('records')
        results = self._process_items_with_retry(items, self.max_retries)
        if not results:
            self.logger.warning("No CoT results generated")
            return []
        output_df = pd.DataFrame(results)
        output_file = storage.write(output_df)
        self.logger.info(f"CoT generation completed, saved to {output_file}")
        self.logger.info(f"Processed {len(results)} items, original {len(items)} items")
        return [self.output_cot_key]
\ No newline at end of file
import pandas as pd
import re
from tqdm import tqdm
from typing import Dict, Optional, Union
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow import get_logger
from dataflow.prompts.text2sql import Text2SQLPromptGeneratorPrompt, Text2VecSQLPromptGeneratorPrompt
from dataflow.core.prompt import prompt_restrict, DIYPromptABC
from dataflow.core import OperatorABC
from dataflow.utils.storage import DataFlowStorage
from dataflow.utils.text2sql.database_manager import DatabaseManager
@prompt_restrict(Text2SQLPromptGeneratorPrompt, Text2VecSQLPromptGeneratorPrompt)
@OPERATOR_REGISTRY.register()
class Text2SQLPromptGenerator(OperatorABC):
    """Build a Text2SQL prompt per row from database details and the question.

    The prompt template combines the database schema details, the natural
    language question, optional evidence, and the database engine name.
    """

    def __init__(self,
                database_manager: DatabaseManager,
                prompt_template: Union[Text2SQLPromptGeneratorPrompt, Text2VecSQLPromptGeneratorPrompt, DIYPromptABC] = None
                ):
        """Store collaborators; default to Text2SQLPromptGeneratorPrompt when no template is given."""
        if prompt_template is None:
            prompt_template = Text2SQLPromptGeneratorPrompt()
        self.prompt_template = prompt_template
        self.logger = get_logger()
        self.database_manager = database_manager

    @staticmethod
    def get_desc(lang):
        """Return the operator description for 'zh'/'en'; short fallback otherwise."""
        if lang == "zh":
            return (
                "从数据库提取Schema信息,结合自然语言问题生成提示词。其中提示词模版支持自定义。\n\n"
                "输入参数:\n"
                "- input_question_key: 问题列名\n"
                "- input_db_id_key: 数据库ID列名\n"
                "- output_prompt_key: 输出prompt列名\n\n"
                "输出参数:\n"
                "- output_prompt_key: 生成的prompt"
            )
        elif lang == "en":
            return (
                "This operator generates prompts for Text2SQL tasks by extracting schema information from databases and combining it with natural language questions. The prompt template can be customized.\n\n"
                "Input parameters:\n"
                "- input_question_key: The name of the question column\n"
                "- input_db_id_key: The name of the database ID column\n"
                "- output_prompt_key: The name of the output prompt column\n\n"
                "Output parameters:\n"
                "- output_prompt_key: The generated prompt"
            )
        else:
            return "Prompt generator for Text2SQL tasks."

    # BUG FIX: the previous `-> str` annotation was wrong — the manager
    # returns a (create_statements, insert_statements) pair, as every other
    # call site unpacks two values from it.
    def get_create_statements_and_insert_statements(self, db_id: str) -> tuple:
        """Return the (create_statements, insert_statements) pair for `db_id`."""
        return self.database_manager.get_create_statements_and_insert_statements(db_id)

    def run(self, storage: DataFlowStorage,
            input_question_key: str = "question",
            input_db_id_key: str = "db_id",
            input_evidence_key: str = "evidence",
            output_prompt_key: str = "prompt"
            ):
        """Generate one prompt per row and write the enriched dataframe.

        Raises:
            ValueError: when the question or db_id column is missing.

        Returns the list of output column names added (``[output_prompt_key]``).
        """
        self.input_question_key = input_question_key
        self.input_db_id_key = input_db_id_key
        self.input_evidence_key = input_evidence_key
        self.output_prompt_key = output_prompt_key
        self.logger.info("Starting prompt generation...")
        raw_dataframe = storage.read("dataframe")
        required_cols = [input_question_key, input_db_id_key]
        missing_cols = [col for col in required_cols if col not in raw_dataframe.columns]
        if missing_cols:
            raise ValueError(f"Missing required columns: {missing_cols}")
        items = raw_dataframe.to_dict('records')
        final_results = []
        for item in tqdm(items, desc="Generating prompts"):
            db_id = item[self.input_db_id_key]
            question = item[self.input_question_key]
            # Evidence is optional; absent keys default to an empty string.
            evidence = item.get(self.input_evidence_key, "")
            # Sanitize the db id: strip newlines and any non [A-Za-z0-9_] chars.
            db_id = re.sub(r'[^A-Za-z0-9_]', '', str(db_id).replace('\n', ''))
            db_details = self.database_manager.get_db_details(db_id)
            prompt = self.prompt_template.build_prompt(
                db_details=db_details,
                question=question,
                evidence=evidence,
                db_engine=self.database_manager.db_type
            )
            final_results.append({
                **item,
                self.output_prompt_key: prompt
            })
        if len(final_results) != len(items):
            self.logger.warning(f"Results count mismatch: expected {len(items)}, got {len(final_results)}")
        output_file = storage.write(pd.DataFrame(final_results))
        self.logger.info(f"Prompt generation completed, saved to {output_file}")
        return [self.output_prompt_key]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment