Commit 97e8278b authored by zzg_666

Adapt the vllm backend

from dataflow.operators.core_text import PromptedFilter
from dataflow.serving import APILLMServing_request
from dataflow.utils.storage import FileStorage
class GPT_evaluator():
def __init__(self):
self.storage = FileStorage(
first_entry_file_name="../example_data/core_text_data/eval_data.json",
cache_path="./cache_1",
file_name_prefix="math_QA",
cache_type="json",
)
self.llm_serving = APILLMServing_request(
api_url="https://api.openai.com/v1/chat/completions",
model_name="gpt-4o",
max_workers=10
)
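# Note: APILLMServing_request reads the API key from the environment rather
# than from code (the variable is typically DF_API_KEY, but this depends on
# your dataflow version; check the serving docs before running).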
self.prompt_evaluator = PromptedFilter(
llm_serving = self.llm_serving,
)
def forward(self):
# Run the prompted filter over the conversations field
self.prompt_evaluator.run(
storage = self.storage.step(),
input_key = "conversations",
output_key = "eval_dim_1",
)
if __name__ == "__main__":
# This is the entry point for the pipeline
model = GPT_evaluator()
model.forward()
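# A minimal sketch of what eval_data.json is assumed to contain, based on
# input_key="conversations" above (values are illustrative, not prescribed
# by the operator):
#
# [{"conversations": "Q: What is 3 + 5? A: 8"}, {"conversations": "Q: Solve x + 2 = 7. A: x = 5"}]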
from dataflow.operators.core_text import PromptedRefiner
from dataflow.serving import APILLMServing_request
from dataflow.utils.storage import FileStorage
class GPT_generator():
def __init__(self):
self.storage = FileStorage(
first_entry_file_name="../example_data/GeneralTextPipeline/abbreviation.jsonl",
cache_path="./cache",
file_name_prefix="math_QA",
cache_type="jsonl",
)
self.model_cache_dir = './dataflow_cache'
self.llm_serving = APILLMServing_request(
api_url="https://api.openai.com/v1/chat/completions",
model_name="gpt-4o",
max_workers=10
)
self.prompt_refiner = PromptedRefiner(
llm_serving = self.llm_serving,
system_prompt = "Please rewrite this sentence into a better one.", # System prompt for math problem solving
)
def forward(self):
# Run the prompted refiner over the raw content
self.prompt_refiner.run(
storage = self.storage.step(),
input_key = "raw_content",
)
if __name__ == "__main__":
# This is the entry point for the pipeline
model = GPT_generator()
model.forward()
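# A minimal sketch of what abbreviation.jsonl is assumed to contain, based on
# input_key="raw_content" above (one JSON object per line; values are
# illustrative):
#
# {"raw_content": "LLMs r good at NLP tasks."}
# {"raw_content": "The mtg is on Fri, pls confirm asap."}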
from dataflow.operators.code import (
CodeAutoGeneratedFilter,
CodeLengthSampleFilter,
CodeTextCompositionFilter,
CodeEncodedDataFilter,
CodeDocumentQualityFilter,
CodeFileTypeContentFilter,
CodeGenericScoreFilter,
)
from dataflow.utils.storage import FileStorage
class PTCodeFilter_CPUPipeline():
def __init__(self):
self.storage = FileStorage(
first_entry_file_name="../example_data/CodePipeline/code_input.jsonl",
cache_path="./cache",
file_name_prefix="dataflow_cache_step",
cache_type="jsonl",
)
self.autogen_filter_step1 = CodeAutoGeneratedFilter(
min_score=1.0,
max_score=1.0
)
self.length_filter_step2 = CodeLengthSampleFilter(
min_score=1.0,
max_score=1.0
)
self.text_composition_filter_step3 = CodeTextCompositionFilter(
min_score=1.0,
max_score=1.0
)
self.encoded_data_filter_step4 = CodeEncodedDataFilter(
min_score=1.0,
max_score=1.0
)
self.doc_quality_filter_step5 = CodeDocumentQualityFilter(
min_score=1.0,
max_score=1.0,
thresholds={
'min_num_chars': 100,
'max_num_chars': 100000,
'min_num_words': 10,
'max_num_words': 50000,
'max_frac_duplicate_lines': 0.3,
'max_frac_duplicate_2gram': 0.3,
'max_frac_duplicate_3gram': 0.3,
'max_frac_duplicate_4gram': 0.3,
'max_frac_duplicate_5gram': 0.3,
'max_frac_curly_bracket': 0.1,
'max_frac_all_caps_words': 0.3,
'min_entropy_unigram': 2.0,
}
)
self.file_type_filter_step6 = CodeFileTypeContentFilter()
self.score_filter_step7 = CodeGenericScoreFilter()
def forward(self):
self.autogen_filter_step1.run(
storage=self.storage.step(),
input_key="lines",
output_key="autogen_filter_label"
)
self.length_filter_step2.run(
storage=self.storage.step(),
input_key="lines",
output_key="length_filter_label"
)
self.text_composition_filter_step3.run(
storage=self.storage.step(),
input_key="text",
output_key="text_composition_filter_label"
)
self.encoded_data_filter_step4.run(
storage=self.storage.step(),
input_key="text",
output_key="encoded_data_filter_label"
)
self.doc_quality_filter_step5.run(
storage=self.storage.step(),
input_key="text",
output_key="doc_quality_filter_label"
)
self.file_type_filter_step6.run(
storage=self.storage.step(),
input_key="dataframe",
output_key="file_type_filter_label"
)
# self.score_filter_step7.run(
# storage=self.storage.step(),
# input_key="quality_score",
# output_key="score_filter_label",
# score_threshold=8,
# filter_method="greater_equal"
# )
if __name__ == "__main__":
model = PTCodeFilter_CPUPipeline()
model.forward()
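# A minimal sketch of what code_input.jsonl is assumed to contain, inferred
# only from the input_key arguments above ("lines", "text", "dataframe"); the
# exact schema depends on your dataflow version:
#
# {"text": "def add(a, b):\n    return a + b", "lines": ["def add(a, b):", "    return a + b"]}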
from dataflow.operators.knowledge_cleaning import (
KBCChunkGenerator,
FileOrURLToMarkdownConverterBatch
)
from dataflow.utils.storage import FileStorage
class KBCleaning_CPUPipeline():
def __init__(self):
self.storage = FileStorage(
first_entry_file_name="../example_data/KBCleaningPipeline/kbc_test_1.jsonl",
cache_path="./.cache/cpu",
file_name_prefix="url_cleaning_step",
cache_type="json",
)
self.knowledge_cleaning_step1 = FileOrURLToMarkdownConverterBatch(
intermediate_dir="../example_data/KBCleaningPipeline/raw/",
lang="en",
mineru_backend="pipeline",
)
self.knowledge_cleaning_step2 = KBCChunkGenerator(
split_method="token",
chunk_size=512,
tokenizer_name="Qwen/Qwen2.5-7B-Instruct",
)
def forward(self):
self.knowledge_cleaning_step1.run(
storage=self.storage.step(),
# input_file=,
# output_key=,
)
self.knowledge_cleaning_step2.run(
storage=self.storage.step(),
# input_file=,
# output_key=,
)
if __name__ == "__main__":
model = KBCleaning_CPUPipeline()
model.forward()
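# A minimal sketch of what kbc_test_1.jsonl is assumed to contain; each row
# presumably names a PDF/URL source for the markdown converter (the field
# name here is a guess, not confirmed by this script):
#
# {"raw_content": "https://arxiv.org/pdf/2303.08774"}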
from dataflow.operators.reasoning import (
ReasoningAnswerFormatterFilter,
ReasoningAnswerGroundTruthFilter,
ReasoningAnswerNgramFilter,
)
from dataflow.utils.storage import FileStorage
class Reasoning_CPUPipeline():
def __init__(self):
self.storage = FileStorage(
first_entry_file_name="../example_data/ReasoningPipeline/pipeline_math_short.json",
cache_path="./cache_local",
file_name_prefix="dataflow_cache_step",
cache_type="jsonl",
)
self.answer_format_filter_step1 = ReasoningAnswerFormatterFilter()
self.answer_groundtruth_filter_step2 = ReasoningAnswerGroundTruthFilter()
self.answer_ngram_filter_step3 = ReasoningAnswerNgramFilter(
min_score = 0.1,
max_score = 1.0,
ngrams = 5
)
def forward(self):
self.answer_format_filter_step1.run(
storage = self.storage.step(),
input_key = "output",
)
self.answer_groundtruth_filter_step2.run(
storage = self.storage.step(),
input_test_answer_key = "output",
input_gt_answer_key = "golden_answer"
)
self.answer_ngram_filter_step3.run(
storage = self.storage.step(),
input_question_key = "instruction",
input_answer_key = "output"
)
if __name__ == "__main__":
model = Reasoning_CPUPipeline()
model.forward()
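# A minimal sketch of what pipeline_math_short.json is assumed to contain,
# based on the keys used above ("instruction", "output", "golden_answer");
# values are illustrative:
#
# {"instruction": "What is 12 * 12?", "output": "12 * 12 = 144, so the answer is 144.", "golden_answer": "144"}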
import os
import zipfile
from dataflow import get_logger
from pathlib import Path
from huggingface_hub import snapshot_download
from dataflow.operators.text2sql import (
Text2SQLPromptGenerator
)
from dataflow.operators.text2sql import (
SQLExecutionFilter
)
from dataflow.operators.text2sql import (
SQLComponentClassifier
)
from dataflow.prompts.text2sql import (
Text2SQLPromptGeneratorPrompt
)
from dataflow.utils.storage import FileStorage
from dataflow.utils.text2sql.database_manager import DatabaseManager
def download_and_extract_database(logger):
dataset_repo_id = "Open-Dataflow/dataflow-Text2SQL-database-example"
local_dir = "./hf_cache"
extract_to = "./downloaded_databases"
logger.info(f"Downloading and extracting database from {dataset_repo_id}...")
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
os.makedirs(local_dir, exist_ok=True)
os.makedirs(extract_to, exist_ok=True)
downloaded_path = snapshot_download(
repo_id=dataset_repo_id,
repo_type="dataset",
local_dir=local_dir,
resume_download=True
)
logger.info(f"Files downloaded to: {downloaded_path}")
zip_path = os.path.join(downloaded_path, "databases.zip")
if os.path.exists(zip_path):
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
zip_ref.extractall(extract_to)
logger.info(f"Database files extracted to {extract_to}")
return extract_to
else:
raise FileNotFoundError(f"Database zip file not found at {zip_path}")
class Text2SQL_CPUPipeline():
def __init__(self, db_root_path=""):
self.logger = get_logger()
self.db_root_path = db_root_path
if not db_root_path:
try:
self.db_root_path = download_and_extract_database(self.logger)
self.logger.info(f"Using automatically downloaded database at: {self.db_root_path}")
except Exception as e:
self.logger.error(f"Failed to auto-download database: {e}")
raise
else:
self.logger.info(f"Using manually specified database path: {self.db_root_path}")
if not os.path.exists(self.db_root_path):
raise FileNotFoundError(f"Database path does not exist: {self.db_root_path}")
self.storage = FileStorage(
first_entry_file_name="../example_data/Text2SQLPipeline/pipeline_refine.jsonl",
cache_path="./cache",
file_name_prefix="dataflow_cache_step",
cache_type="jsonl",
)
# SQLite and MySQL are currently supported.
# db_type can be "sqlite" or "mysql" and must match your database type.
# If sqlite is selected, root_path must be provided; the path must exist and contain the database files.
# If mysql is selected, host, user, and password must be provided; the credentials must be correct and grant access.
# MySQL example:
# database_manager = DatabaseManager(
# db_type="mysql",
# config={
# "host": "localhost",
# "user": "root",
# "password": "your_password",
# "database": "your_database_name"
# }
# )
# SQLite example:
database_manager = DatabaseManager(
db_type="sqlite",
config={
"root_path": self.db_root_path
},
logger=None,
sql_execution_timeout = 2,
max_connections_per_db=100,
max_workers=100
)
self.sql_execution_filter_step1 = SQLExecutionFilter(
database_manager=database_manager,
)
self.text2sql_prompt_generator_step2 = Text2SQLPromptGenerator(
database_manager=database_manager,
prompt_template=Text2SQLPromptGeneratorPrompt()
)
self.sql_component_classifier_step3 = SQLComponentClassifier(
difficulty_thresholds=[2, 4, 6],
difficulty_labels=['easy', 'medium', 'hard', 'extra']
)
def forward(self):
sql_key = "SQL"
db_id_key = "db_id"
question_key = "question"
self.sql_execution_filter_step1.run(
storage=self.storage.step(),
input_sql_key=sql_key,
input_db_id_key=db_id_key
)
self.text2sql_prompt_generator_step2.run(
storage=self.storage.step(),
input_question_key=question_key,
input_db_id_key=db_id_key,
output_prompt_key="prompt"
)
self.sql_component_classifier_step3.run(
storage=self.storage.step(),
input_sql_key=sql_key,
output_difficulty_key="sql_component_difficulty"
)
if __name__ == "__main__":
# If you have your own database files, set db_root_path to their location.
# If not, leave db_root_path as an empty string ("") and the example database files will be downloaded automatically.
db_root_path = ""
model = Text2SQL_CPUPipeline(db_root_path=db_root_path)
model.forward()
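# A minimal sketch of what pipeline_refine.jsonl is assumed to contain, based
# on the keys used above ("SQL", "db_id", "question"); values are
# illustrative:
#
# {"SQL": "SELECT name FROM singer WHERE age > 30", "db_id": "concert_singer", "question": "List the names of singers older than 30."}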
from dataflow.operators.general_text import (
WordNumberFilter,
BlocklistFilter,
MinHashDeduplicateFilter,
ColonEndFilter,
SentenceNumberFilter,
LineEndWithEllipsisFilter,
ContentNullFilter,
MeanWordLengthFilter,
SymbolWordRatioFilter,
HtmlEntityFilter,
NoPuncFilter,
SpecialCharacterFilter,
WatermarkFilter,
CurlyBracketFilter,
CapitalWordsFilter,
LoremIpsumFilter,
UniqueWordsFilter,
CharNumberFilter,
LineStartWithBulletpointFilter,
LineWithJavascriptFilter,
HtmlUrlRemoverRefiner,
RemoveEmojiRefiner,
RemoveExtraSpacesRefiner
)
from dataflow.operators.text_pt import (
MetaSampleEvaluator,
)
from dataflow.utils.storage import FileStorage
class PTTextFilter_CPUPipeline():
def __init__(self):
self.storage = FileStorage(
first_entry_file_name="../example_data/GeneralTextPipeline/pt_input.jsonl",
cache_path="./cache",
file_name_prefix="dataflow_cache_step",
cache_type="jsonl",
)
self.remove_extra_spaces_refiner = RemoveExtraSpacesRefiner()
self.remove_emoji_refiner = RemoveEmojiRefiner()
self.html_remove_refiner = HtmlUrlRemoverRefiner()
self.minhash_deduplicator = MinHashDeduplicateFilter(num_perm=128, threshold=0.9, use_n_gram=True, ngram=5)
self.blocklist_filter = BlocklistFilter()
self.word_number_filter = WordNumberFilter(min_words=20, max_words=100000)
self.colon_end_filter = ColonEndFilter()
self.sentence_number_filter = SentenceNumberFilter(min_sentences=3, max_sentences=7500)
self.line_end_with_ellipsis_filter = LineEndWithEllipsisFilter(threshold=0.3)
self.content_null_filter = ContentNullFilter()
self.mean_word_length_filter = MeanWordLengthFilter(min_length=3, max_length=10)
self.symbol_word_ratio_filter = SymbolWordRatioFilter(threshold=0.4)
self.html_entity_filter = HtmlEntityFilter()
self.no_punc_filter = NoPuncFilter(threshold=112)
self.special_character_filter = SpecialCharacterFilter()
self.watermark_filter = WatermarkFilter(watermarks=['Copyright', 'Watermark', 'Confidential'])
self.curly_bracket_filter = CurlyBracketFilter(threshold=0.025)
self.capital_words_filter = CapitalWordsFilter(threshold=0.2, use_tokenizer=False)
self.lorem_ipsum_filter = LoremIpsumFilter(threshold=3e-8)
self.unique_words_filter = UniqueWordsFilter(threshold=0.1)
self.char_number_filter = CharNumberFilter(threshold=100)
self.line_start_with_bulletpoint_filter = LineStartWithBulletpointFilter(threshold=0.9)
self.line_with_javascript_filter = LineWithJavascriptFilter(threshold=3)
def forward(self):
self.remove_emoji_refiner.run(
storage=self.storage.step(),
input_key="raw_content"
)
self.html_remove_refiner.run(
storage=self.storage.step(),
input_key="raw_content"
)
self.remove_extra_spaces_refiner.run(
storage=self.storage.step(),
input_key="raw_content"
)
self.minhash_deduplicator.run(
storage = self.storage.step(),
input_key='raw_content',
)
self.blocklist_filter.run(
storage = self.storage.step(),
input_key='raw_content',
)
self.word_number_filter.run(
storage = self.storage.step(),
input_key='raw_content',
)
self.colon_end_filter.run(
storage = self.storage.step(),
input_key = 'raw_content'
)
self.sentence_number_filter.run(
storage = self.storage.step(),
input_key = 'raw_content'
)
self.line_end_with_ellipsis_filter.run(
storage = self.storage.step(),
input_key = 'raw_content'
)
self.content_null_filter.run(
storage = self.storage.step(),
input_key='raw_content',
)
self.mean_word_length_filter.run(
storage = self.storage.step(),
input_key='raw_content',
)
self.symbol_word_ratio_filter.run(
storage = self.storage.step(),
input_key='raw_content',
)
self.html_entity_filter.run(
storage = self.storage.step(),
input_key='raw_content',
)
self.no_punc_filter.run(
storage = self.storage.step(),
input_key='raw_content',
)
self.special_character_filter.run(
storage = self.storage.step(),
input_key='raw_content',
)
self.watermark_filter.run(
storage = self.storage.step(),
input_key='raw_content',
)
self.curly_bracket_filter.run(
storage = self.storage.step(),
input_key='raw_content',
)
self.capital_words_filter.run(
storage = self.storage.step(),
input_key='raw_content',
)
self.lorem_ipsum_filter.run(
storage = self.storage.step(),
input_key='raw_content',
)
self.unique_words_filter.run(
storage = self.storage.step(),
input_key='raw_content',
)
self.char_number_filter.run(
storage = self.storage.step(),
input_key='raw_content',
)
self.line_start_with_bulletpoint_filter.run(
storage = self.storage.step(),
input_key='raw_content',
)
self.line_with_javascript_filter.run(
storage = self.storage.step(),
input_key='raw_content',
)
if __name__ == "__main__":
# This is the entry point for the pipeline
model = PTTextFilter_CPUPipeline()
model.forward()
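# A minimal sketch of what pt_input.jsonl is assumed to contain; every
# operator above reads the same "raw_content" field (by convention the
# refiners rewrite it in place, while the filters drop non-conforming rows):
#
# {"raw_content": "Transformers process text as sequences of tokens. Attention lets each token weigh every other token."}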
from dataflow.operators.general_text import WordNumberFilter
from dataflow.utils.storage import FileStorage
class SFTTextFilter_CPUPipeline():
def __init__(self):
self.storage = FileStorage(
first_entry_file_name="../example_data/GeneralTextPipeline/sft_input.jsonl",
cache_path="./cache",
file_name_prefix="dataflow_cache_step",
cache_type="jsonl",
)
self.model_cache_dir = './dataflow_cache'
self.word_number_filter_step1 = WordNumberFilter(
min_words=20,
max_words=1000
)
def forward(self):
self.word_number_filter_step1.run(
storage=self.storage.step(),
input_key="output",
)
if __name__ == "__main__":
# This is the entry point for the pipeline
pipeline = SFTTextFilter_CPUPipeline()
pipeline.forward()
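# A runnable sketch for producing a compatible sft_input.jsonl; the schema is
# an assumption based on input_key="output" above:
#
# from pathlib import Path
# import json
# rows = [{"instruction": "Summarize the text.", "output": "A sufficiently long example answer " * 5}]
# Path("sft_input.jsonl").write_text("\n".join(json.dumps(r) for r in rows), encoding="utf-8")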
from dataflow.operators.agentic_rag import (
AutoPromptGenerator,
QAGenerator,
QAScorer
)
from dataflow.operators.agentic_rag import (
ContentChooser
)
from dataflow.utils.storage import FileStorage
from dataflow.serving import LocalModelLLMServing_vllm, LocalModelLLMServing_sglang
class AgenticRAG_GPUPipeline():
def __init__(self):
self.storage = FileStorage(
first_entry_file_name="../example_data/AgenticRAGPipeline/pipeline_small_chunk.json",
cache_path="./cache_local",
file_name_prefix="dataflow_cache_step",
cache_type="json",
)
# use vllm as LLM serving
self.llm_serving = LocalModelLLMServing_vllm(
hf_model_name_or_path="Qwen2.5-7B-Instruct", # set to your own model path
vllm_tensor_parallel_size=4,
vllm_max_tokens=8192,
)
# use SGLang as LLM serving
# llm_serving = LocalModelLLMServing_sglang(
# hf_model_name_or_path="Qwen/Qwen2.5-7B-Instruct",
# sgl_dp_size=1, # data parallel size
# sgl_tp_size=1, # tensor parallel size
# sgl_max_new_tokens=1024,
# )
embedding_serving = LocalModelLLMServing_vllm(
hf_model_name_or_path="Alibaba-NLP/gte-Qwen2-7B-instruct",
vllm_max_tokens=8192
)
self.content_chooser_step1 = ContentChooser(embedding_serving=embedding_serving, num_samples=5, method="random")
self.prompt_generator_step2 = AutoPromptGenerator(self.llm_serving)
self.qa_generator_step3 = QAGenerator(self.llm_serving)
self.qa_scorer_step4 = QAScorer(self.llm_serving)
def forward(self):
self.content_chooser_step1.run(
storage = self.storage.step(),
input_key = "text"
)
self.prompt_generator_step2.run(
storage = self.storage.step(),
input_key = "text"
)
self.qa_generator_step3.run(
storage = self.storage.step(),
input_key="text",
output_prompt_key="generated_prompt",
output_quesion_key="generated_question",
output_answer_key="generated_answer"
)
self.qa_scorer_step4.run(
storage = self.storage.step(),
input_question_key="generated_question",
input_answer_key="generated_answer",
output_question_quality_key="question_quality_grades",
output_question_quality_feedback_key="question_quality_feedbacks",
output_answer_alignment_key="answer_alignment_grades",
output_answer_alignment_feedback_key="answer_alignment_feedbacks",
output_answer_verifiability_key="answer_verifiability_grades",
)
if __name__ == "__main__":
model = AgenticRAG_GPUPipeline()
model.forward()
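# A minimal sketch of what pipeline_small_chunk.json is assumed to contain,
# based on input_key="text" above (values are illustrative):
#
# [{"text": "Photosynthesis converts light energy into chemical energy stored in glucose."}]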
from dataflow.operators.core_text import BenchDatasetEvaluator
from dataflow.operators.reasoning import ReasoningAnswerGenerator
from dataflow.utils.storage import FileStorage
from dataflow.serving import LocalModelLLMServing_vllm
from dataflow.core import LLMServingABC
from dataflow.prompts.reasoning.diy import (
DiyAnswerGeneratorPrompt,
)
DIY_PROMPT_ANSWER = """Please output the answer."""
class BenchEvalPipeline():
def __init__(self, llm_serving_generator: LLMServingABC = None, llm_serving_judger: LLMServingABC = None):
self.storage = FileStorage(
first_entry_file_name="../example_data/core_text_data/bench_eval_data.jsonl",
cache_path="./cache_local",
file_name_prefix="dataflow_cache_step",
cache_type="jsonl",
)
self.llm_serving_generator = LocalModelLLMServing_vllm(
hf_model_name_or_path="Qwen/Qwen2.5-7B-Instruct", # set to your own model path
vllm_tensor_parallel_size=1,
vllm_max_tokens=2048,
)
self.answer_generator_step1 = ReasoningAnswerGenerator(
llm_serving=self.llm_serving_generator,
prompt_template=DiyAnswerGeneratorPrompt(DIY_PROMPT_ANSWER)
)
self.evaluator_step2 = BenchDatasetEvaluator(
eval_result_path="./cache_local/eval_result/eval_result.jsonl",
compare_method="match", # or semantic
prompt_template = None # you can diy your judger prompt in dataflow.prompts.reasoning.general.AnswerJudgePrompt
)
def forward(self):
self.answer_generator_step1.run(
storage = self.storage.step(),
input_key = "instruction",
output_key = "generated_cot"
)
self.evaluator_step2.run(
storage = self.storage.step(),
input_test_answer_key="generated_cot",
input_gt_answer_key="golden_answer",
input_question_key="instruction",
)
if __name__ == "__main__":
pl = BenchEvalPipeline()
pl.forward()
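# A minimal sketch of what bench_eval_data.jsonl is assumed to contain, based
# on the keys used above ("instruction", "golden_answer"); values are
# illustrative:
#
# {"instruction": "What is the capital of France?", "golden_answer": "Paris"}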
from dataflow.operators.knowledge_cleaning import (
KBCChunkGeneratorBatch,
FileOrURLToMarkdownConverterBatch,
KBCTextCleanerBatch,
KBCMultiHopQAGeneratorBatch,
)
from dataflow.utils.storage import FileStorage
from dataflow.serving import LocalModelLLMServing_vllm, LocalModelLLMServing_sglang
class KBCleaning_batchSglang_GPUPipeline():
def __init__(self):
self.storage = FileStorage(
first_entry_file_name="../../example_data/KBCleaningPipeline/kbc_test.jsonl",
cache_path="./.cache/gpu",
file_name_prefix="batch_cleaning_step",
cache_type="json",
)
self.knowledge_cleaning_step1 = FileOrURLToMarkdownConverterBatch(
intermediate_dir="../../example_data/KBCleaningPipeline/raw/",
lang="en",
mineru_backend="vlm-vllm-engine",
)
self.knowledge_cleaning_step2 = KBCChunkGeneratorBatch(
split_method="token",
chunk_size=512,
tokenizer_name="Qwen/Qwen2.5-7B-Instruct",
)
def forward(self):
self.knowledge_cleaning_step1.run(
storage=self.storage.step(),
)
self.knowledge_cleaning_step2.run(
storage=self.storage.step(),
)
self.llm_serving = LocalModelLLMServing_sglang(
hf_model_name_or_path="Qwen/Qwen2.5-7B-Instruct",
sgl_dp_size=1, # data parallel size
sgl_tp_size=1, # tensor parallel size
sgl_max_new_tokens=2048,
)
self.knowledge_cleaning_step3 = KBCTextCleanerBatch(
llm_serving=self.llm_serving,
lang="en"
)
self.knowledge_cleaning_step4 = KBCMultiHopQAGeneratorBatch(
llm_serving=self.llm_serving,
lang="en"
)
self.knowledge_cleaning_step3.run(
storage=self.storage.step(),
)
self.knowledge_cleaning_step4.run(
storage=self.storage.step(),
)
if __name__ == "__main__":
model = KBCleaning_batchSglang_GPUPipeline()
model.forward()
from dataflow.operators.knowledge_cleaning import (
KBCChunkGeneratorBatch,
FileOrURLToMarkdownConverterBatch,
KBCTextCleanerBatch,
KBCMultiHopQAGeneratorBatch,
)
from dataflow.utils.storage import FileStorage
from dataflow.serving import LocalModelLLMServing_vllm, LocalModelLLMServing_sglang
class KBCleaning_batchvllm_GPUPipeline():
def __init__(self):
self.storage = FileStorage(
first_entry_file_name="../../example_data/KBCleaningPipeline/kbc_test.jsonl",
cache_path="./.cache/gpu",
file_name_prefix="batch_cleaning_step",
cache_type="json",
)
self.knowledge_cleaning_step1 = FileOrURLToMarkdownConverterBatch(
intermediate_dir="../../example_data/KBCleaningPipeline/raw/",
lang="en",
mineru_backend="vlm-vllm-engine",
)
self.knowledge_cleaning_step2 = KBCChunkGeneratorBatch(
split_method="token",
chunk_size=512,
tokenizer_name="Qwen/Qwen2.5-7B-Instruct",
)
def forward(self):
self.knowledge_cleaning_step1.run(
storage=self.storage.step(),
)
self.knowledge_cleaning_step2.run(
storage=self.storage.step(),
)
self.llm_serving = LocalModelLLMServing_vllm(
hf_model_name_or_path="Qwen/Qwen2.5-7B-Instruct",
vllm_max_tokens=2048,
vllm_tensor_parallel_size=4,
vllm_gpu_memory_utilization=0.6,
vllm_repetition_penalty=1.2
)
self.knowledge_cleaning_step3 = KBCTextCleanerBatch(
llm_serving=self.llm_serving,
lang="en"
)
self.knowledge_cleaning_step4 = KBCMultiHopQAGeneratorBatch(
llm_serving=self.llm_serving,
lang="en"
)
self.knowledge_cleaning_step3.run(
storage=self.storage.step(),
)
self.knowledge_cleaning_step4.run(
storage=self.storage.step(),
)
if __name__ == "__main__":
model = KBCleaning_batchvllm_GPUPipeline()
model.forward()
from dataflow.operators.knowledge_cleaning import (
KBCChunkGenerator,
FileOrURLToMarkdownConverterBatch,
KBCTextCleaner,
# KBCMultiHopQAGenerator,
)
from dataflow.operators.core_text import Text2MultiHopQAGenerator
from dataflow.utils.storage import FileStorage
from dataflow.serving import LocalModelLLMServing_vllm, LocalModelLLMServing_sglang
class KBCleaning_PDFSglang_GPUPipeline():
def __init__(self):
self.storage = FileStorage(
first_entry_file_name="../../example_data/KBCleaningPipeline/kbc_test.jsonl",
cache_path="./.cache/gpu",
file_name_prefix="knowledge_cleaning_step_sglang_engine",
cache_type="json",
)
self.knowledge_cleaning_step1 = FileOrURLToMarkdownConverterBatch(
intermediate_dir="../../example_data/KBCleaningPipeline/raw/",
lang="en",
mineru_backend="vlm-vllm-engine",
)
self.knowledge_cleaning_step2 = KBCChunkGenerator(
split_method="token",
chunk_size=512,
tokenizer_name="Qwen/Qwen2.5-7B-Instruct",
)
def forward(self):
self.knowledge_cleaning_step1.run(
storage=self.storage.step(),
# input_key=
# output_key=
)
self.knowledge_cleaning_step2.run(
storage=self.storage.step(),
# input_key=
# output_key=
)
self.llm_serving = LocalModelLLMServing_sglang(
hf_model_name_or_path="Qwen/Qwen2.5-7B-Instruct",
sgl_dp_size=1, # data parallel size
sgl_tp_size=1, # tensor parallel size
sgl_max_new_tokens=2048,
)
self.knowledge_cleaning_step3 = KBCTextCleaner(
llm_serving=self.llm_serving,
lang="en"
)
self.knowledge_cleaning_step4 = Text2MultiHopQAGenerator(
llm_serving=self.llm_serving,
lang="en",
num_q = 5
)
self.knowledge_cleaning_step3.run(
storage=self.storage.step(),
# input_key=
# output_key=
)
self.knowledge_cleaning_step4.run(
storage=self.storage.step(),
# input_key=
# output_key=
)
if __name__ == "__main__":
model = KBCleaning_PDFSglang_GPUPipeline()
model.forward()
from dataflow.operators.knowledge_cleaning import (
KBCChunkGenerator,
FileOrURLToMarkdownConverterBatch,
KBCTextCleaner,
# KBCMultiHopQAGenerator,
)
from dataflow.operators.core_text import Text2MultiHopQAGenerator
from dataflow.utils.storage import FileStorage
from dataflow.serving import LocalModelLLMServing_vllm
class KBCleaning_PDFvllm_GPUPipeline():
def __init__(self):
self.storage = FileStorage(
first_entry_file_name="../../example_data/KBCleaningPipeline/kbc_test.jsonl",
cache_path="./.cache/gpu",
file_name_prefix="knowledge_cleaning_step_vllm_engine",
cache_type="json",
)
self.knowledge_cleaning_step1 = FileOrURLToMarkdownConverterBatch(
intermediate_dir="../../example_data/KBCleaningPipeline/raw/",
lang="en",
mineru_backend="vlm-vllm-engine",
)
self.knowledge_cleaning_step2 = KBCChunkGenerator(
split_method="token",
chunk_size=512,
tokenizer_name="Qwen/Qwen2.5-7B-Instruct",
)
def forward(self):
self.knowledge_cleaning_step1.run(
storage=self.storage.step(),
# input_key=
# output_key=
)
self.knowledge_cleaning_step2.run(
storage=self.storage.step(),
# input_key=
# output_key=
)
self.llm_serving = LocalModelLLMServing_vllm(
hf_model_name_or_path="Qwen/Qwen2.5-7B-Instruct",
vllm_max_tokens=2048,
vllm_tensor_parallel_size=4,
vllm_gpu_memory_utilization=0.6,
vllm_repetition_penalty=1.2
)
self.knowledge_cleaning_step3 = KBCTextCleaner(
llm_serving=self.llm_serving,
lang="en"
)
self.knowledge_cleaning_step4 = Text2MultiHopQAGenerator(
llm_serving=self.llm_serving,
lang="en",
num_q = 5
)
self.knowledge_cleaning_step3.run(
storage=self.storage.step(),
# input_key=
# output_key=
)
self.knowledge_cleaning_step4.run(
storage=self.storage.step(),
# input_key=
# output_key=
)
if __name__ == "__main__":
model = KBCleaning_PDFvllm_GPUPipeline()
model.forward()
from dataflow.operators.reasoning import (
ReasoningCategoryDatasetEvaluator,
ReasoningDifficultyDatasetEvaluator,
ReasoningQuestionGenerator,
ReasoningAnswerGenerator,
)
from dataflow.operators.reasoning import (
ReasoningQuestionFilter,
ReasoningAnswerPipelineRootFilter,
ReasoningAnswerFormatterFilter,
ReasoningAnswerTokenLengthFilter,
ReasoningAnswerGroundTruthFilter,
ReasoningAnswerNgramFilter,
)
from dataflow.prompts.reasoning.math import (
MathQuestionFilterPrompt,
MathAnswerGeneratorPrompt,
MathQuestionSynthesisPrompt
)
from dataflow.utils.storage import FileStorage
from dataflow.serving import LocalModelLLMServing_vllm, LocalModelLLMServing_sglang
class ReasoningMath_GPUPipeline():
def __init__(self):
self.storage = FileStorage(
first_entry_file_name="../example_data/ReasoningPipeline/pipeline_math_short.json",
cache_path="./cache_local",
file_name_prefix="dataflow_cache_step",
cache_type="jsonl",
)
# use vllm as LLM serving
self.llm_serving = LocalModelLLMServing_vllm(
hf_model_name_or_path="Qwen/Qwen2.5-7B-Instruct", # set to your own model path
vllm_tensor_parallel_size=1,
vllm_max_tokens=8192,
)
# use SGLang as LLM serving
# llm_serving = LocalModelLLMServing_sglang(
# hf_model_name_or_path="Qwen/Qwen2.5-7B-Instruct",
# sgl_dp_size=1, # data parallel size
# sgl_tp_size=1, # tensor parallel size
# sgl_max_new_tokens=1024,
# )
self.question_filter_step1 = ReasoningQuestionFilter(
system_prompt="You are an expert in evaluating mathematical problems. Follow the user's instructions strictly and output your final judgment in the required JSON format.",
llm_serving=self.llm_serving,
prompt_template=MathQuestionFilterPrompt()
)
self.question_gen_step2 = ReasoningQuestionGenerator(
num_prompts=3,
llm_serving=self.llm_serving,
prompt_template=MathQuestionSynthesisPrompt()
)
self.question_filter_step3 = ReasoningQuestionFilter(
system_prompt="You are an expert in evaluating mathematical problems. Follow the user's instructions strictly and output your final judgment in the required JSON format.",
llm_serving=self.llm_serving,
prompt_template=MathQuestionFilterPrompt()
)
self.question_difficulty_classifier_step4 = ReasoningDifficultyDatasetEvaluator(
llm_serving=self.llm_serving
)
self.question_category_classifier_step5 = ReasoningCategoryDatasetEvaluator(
llm_serving=self.llm_serving
)
########################## branch ############################
# self.answer_pipeline_root_step6 = ReasoningAnswerPipelineRootFilter()
########################## answer ############################
self.answer_generator_step7 = ReasoningAnswerGenerator(
llm_serving=self.llm_serving,
prompt_template=MathAnswerGeneratorPrompt()
)
self.answer_format_filter_step8 = ReasoningAnswerFormatterFilter()
self.answer_token_length_filter_step9 = ReasoningAnswerTokenLengthFilter(
max_answer_token_length = 8192,
tokenizer_dir = "Qwen/Qwen2.5-0.5B-Instruct"
)
self.answer_groundtruth_filter_step10 = ReasoningAnswerGroundTruthFilter()
self.answer_ngram_filter_step11 = ReasoningAnswerNgramFilter(
min_score = 0.1,
max_score = 1.0,
ngrams = 5
)
def forward(self):
self.question_filter_step1.run(
storage = self.storage.step(),
input_key = "instruction",
)
self.question_gen_step2.run(
storage = self.storage.step(),
input_key = "instruction",
)
self.question_filter_step3.run(
storage = self.storage.step(),
input_key = "instruction",
)
self.question_difficulty_classifier_step4.run(
storage = self.storage.step(),
input_key = "instruction",
output_key = "question_difficulty"
)
self.question_category_classifier_step5.run(
storage = self.storage.step(),
input_key = "instruction",
output_key = "question_category"
)
############# branch #############
# self.answer_pipeline_root_step6.run(
# storage = self.storage.step(),
# input_answer_key = "output",
# input_gt_key = "golden_answer"
# )
############## answer #############
self.answer_generator_step7.run(
storage = self.storage.step(),
input_key = "instruction",
output_key = "generated_cot"
)
self.answer_format_filter_step8.run(
storage = self.storage.step(),
input_key = "generated_cot",
)
self.answer_token_length_filter_step9.run(
storage = self.storage.step(),
input_key = "generated_cot"
)
self.answer_groundtruth_filter_step10.run(
storage = self.storage.step(),
input_test_answer_key = "generated_cot",
input_gt_answer_key = "golden_answer"
)
self.answer_ngram_filter_step11.run(
storage = self.storage.step(),
input_question_key = "instruction",
input_answer_key = "generated_cot"
)
if __name__ == "__main__":
model = ReasoningMath_GPUPipeline()
model.forward()
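# The llm_serving above runs locally through vLLM. As an alternative, the
# same pipeline can point at an API backend instead, mirroring the
# APILLMServing_request configuration from the earlier examples (the model
# choice is illustrative):
#
# from dataflow.serving import APILLMServing_request
# self.llm_serving = APILLMServing_request(
# api_url="https://api.openai.com/v1/chat/completions",
# model_name="gpt-4o",
# max_workers=10
# )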
from dataflow.operators.core_speech import Speech2TextGenerator
from dataflow.serving import LocalModelLALMServing_vllm
from dataflow.utils.storage import FileStorage
class SpeechTranscription_GPUPipeline():
def __init__(self):
self.storage = FileStorage(
first_entry_file_name="../example_data/SpeechTranscription/pipeline_speechtranscription.jsonl",
cache_path="./cache",
file_name_prefix="dataflow_cache_step",
cache_type="jsonl",
)
self.llm_serving = LocalModelLALMServing_vllm(
hf_model_name_or_path='Qwen/Qwen2-Audio-7B-Instruct',
vllm_tensor_parallel_size=4,
vllm_max_tokens=8192,
)
self.speech_transcriptor = Speech2TextGenerator(
llm_serving = self.llm_serving,
system_prompt="你是一个专业的翻译员,你需要将语音转录为文本。"
)
def forward(self):
self.speech_transcriptor.run(
storage=self.storage.step(),
input_key="raw_content"
)
if __name__ == "__main__":
pipeline = SpeechTranscription_GPUPipeline()
pipeline.forward()
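# A minimal sketch of what pipeline_speechtranscription.jsonl is assumed to
# contain; "raw_content" presumably points at an audio file the serving
# backend can load (both the path and the field semantics are assumptions):
#
# {"raw_content": "../example_data/SpeechTranscription/sample_0.wav"}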
import os
from dataflow import get_logger
import zipfile
from pathlib import Path
from huggingface_hub import snapshot_download
from dataflow.operators.text2sql import (
SQLGenerator,
Text2SQLQuestionGenerator,
Text2SQLPromptGenerator,
Text2SQLCoTGenerator
)
from dataflow.operators.text2sql import (
SQLExecutionFilter
)
from dataflow.operators.text2sql import (
SQLComponentClassifier,
SQLExecutionClassifier
)
from dataflow.prompts.text2sql import (
Text2SQLCotGeneratorPrompt,
SelectSQLGeneratorPrompt,
Text2SQLQuestionGeneratorPrompt,
Text2SQLPromptGeneratorPrompt
)
from dataflow.utils.storage import FileStorage
from dataflow.serving import LocalModelLLMServing_vllm, LocalModelLLMServing_sglang
from dataflow.utils.text2sql.database_manager import DatabaseManager
def download_and_extract_database(logger):
dataset_repo_id = "Open-Dataflow/dataflow-Text2SQL-database-example"
local_dir = "./hf_cache"
extract_to = "./downloaded_databases"
logger.info(f"Downloading and extracting database from {dataset_repo_id}...")
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
os.makedirs(local_dir, exist_ok=True)
os.makedirs(extract_to, exist_ok=True)
downloaded_path = snapshot_download(
repo_id=dataset_repo_id,
repo_type="dataset",
local_dir=local_dir,
resume_download=True
)
logger.info(f"Files downloaded to: {downloaded_path}")
zip_path = os.path.join(downloaded_path, "databases.zip")
if os.path.exists(zip_path):
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
zip_ref.extractall(extract_to)
logger.info(f"Database files extracted to {extract_to}")
return extract_to
else:
raise FileNotFoundError(f"Database zip file not found at {zip_path}")
class Text2SQLGeneration_GPUPipeline():
def __init__(self, db_root_path=""):
self.logger = get_logger()
self.db_root_path = db_root_path
if not db_root_path:
try:
self.db_root_path = download_and_extract_database(self.logger)
self.logger.info(f"Using automatically downloaded database at: {self.db_root_path}")
except Exception as e:
self.logger.error(f"Failed to auto-download database: {e}")
raise
else:
self.logger.info(f"Using manually specified database path: {self.db_root_path}")
if not os.path.exists(self.db_root_path):
raise FileNotFoundError(f"Database path does not exist: {self.db_root_path}")
self.storage = FileStorage(
first_entry_file_name="",
cache_path="./cache",
file_name_prefix="dataflow_cache_step",
cache_type="jsonl",
)
self.llm_serving = LocalModelLLMServing_vllm(
hf_model_name_or_path="Qwen/Qwen2.5-7B-Instruct", # set to your own model path
vllm_tensor_parallel_size=1,
vllm_max_tokens=8192,
)
# use SGLang as LLM serving
# llm_serving = LocalModelLLMServing_sglang(
# hf_model_name_or_path="Qwen/Qwen2.5-7B-Instruct",
# sgl_dp_size=1, # data parallel size
# sgl_tp_size=1, # tensor parallel size
# sgl_max_new_tokens=1024,
# )
# It is recommended to use a stronger LLM for generating the Chain-of-Thought (CoT) reasoning.
cot_generation_llm_serving = LocalModelLLMServing_vllm(
hf_model_name_or_path="Qwen/Qwen2.5-7B-Instruct", # set to your own model path
vllm_tensor_parallel_size=1,
vllm_max_tokens=8192,
)
embedding_serving = LocalModelLLMServing_vllm(
hf_model_name_or_path="Alibaba-NLP/gte-Qwen2-7B-instruct",
vllm_max_tokens=8192
)
# SQLite and MySQL are currently supported.
# db_type can be "sqlite" or "mysql" and must match your database type.
# If sqlite is selected, root_path must be provided; the path must exist and contain the database files.
# If mysql is selected, host, user, and password must be provided; the credentials must be correct and grant access.
# MySQL example:
# database_manager = DatabaseManager(
# db_type="mysql",
# config={
# "host": "localhost",
# "user": "root",
# "password": "your_password",
# "database": "your_database_name"
# }
# )
# SQLite example:
database_manager = DatabaseManager(
db_type="sqlite",
config={
"root_path": self.db_root_path
},
logger=None,
max_connections_per_db=100,
max_workers=100
)
self.sql_generator_step1 = SQLGenerator(
llm_serving=self.llm_serving,
database_manager=database_manager,
generate_num=10,
prompt_template=SelectSQLGeneratorPrompt()
)
self.sql_execution_filter_step2 = SQLExecutionFilter(
database_manager=database_manager,
)
self.text2sql_question_generator_step3 = Text2SQLQuestionGenerator(
llm_serving=self.llm_serving,
embedding_serving=embedding_serving,
database_manager=database_manager,
question_candidates_num=5,
prompt_template=Text2SQLQuestionGeneratorPrompt()
)
self.text2sql_prompt_generator_step4 = Text2SQLPromptGenerator(
database_manager=database_manager,
prompt_template=Text2SQLPromptGeneratorPrompt()
)
self.sql_cot_generator_step5 = Text2SQLCoTGenerator(
llm_serving=cot_generation_llm_serving,
database_manager=database_manager,
max_retries=3,
enable_retry=True,
prompt_template=Text2SQLCotGeneratorPrompt()
)
self.sql_component_classifier_step6 = SQLComponentClassifier(
difficulty_thresholds=[2, 4, 6],
difficulty_labels=['easy', 'medium', 'hard', 'extra']
)
self.sql_execution_classifier_step7 = SQLExecutionClassifier(
llm_serving=self.llm_serving,
database_manager=database_manager,
num_generations=10,
difficulty_thresholds=[2, 5, 9],
difficulty_labels=['extra', 'hard', 'medium', 'easy']
)
def forward(self):
sql_key = "SQL"
db_id_key = "db_id"
question_key = "question"
self.sql_generator_step1.run(
storage=self.storage.step(),
output_sql_key=sql_key,
output_db_id_key=db_id_key
)
self.sql_execution_filter_step2.run(
storage=self.storage.step(),
input_sql_key=sql_key,
input_db_id_key=db_id_key
)
self.text2sql_question_generator_step3.run(
storage=self.storage.step(),
input_sql_key=sql_key,
input_db_id_key=db_id_key,
output_question_key=question_key
)
self.text2sql_prompt_generator_step4.run(
storage=self.storage.step(),
input_question_key=question_key,
input_db_id_key=db_id_key,
output_prompt_key="prompt"
)
self.sql_cot_generator_step5.run(
storage=self.storage.step(),
input_sql_key=sql_key,
input_question_key=question_key,
input_db_id_key=db_id_key,
output_cot_key="cot_reasoning"
)
self.sql_component_classifier_step6.run(
storage=self.storage.step(),
input_sql_key=sql_key,
output_difficulty_key="sql_component_difficulty"
)
self.sql_execution_classifier_step7.run(
storage=self.storage.step(),
input_sql_key=sql_key,
input_db_id_key=db_id_key,
input_prompt_key="prompt",
output_difficulty_key="sql_execution_difficulty"
)
if __name__ == "__main__":
# If you have your own database files, set db_root_path to their location.
# If not, leave db_root_path as an empty string ("") and the example database files will be downloaded automatically.
db_root_path = ""
model = Text2SQLGeneration_GPUPipeline(db_root_path=db_root_path)
model.forward()
import os
from dataflow import get_logger
import zipfile
from huggingface_hub import snapshot_download
from dataflow.operators.text2sql import (
SQLVariationGenerator,
Text2SQLQuestionGenerator,
Text2SQLPromptGenerator,
Text2SQLCoTGenerator
)
from dataflow.operators.text2sql import (
SQLExecutionFilter,
SQLConsistencyFilter
)
from dataflow.operators.text2sql import (
SQLComponentClassifier,
SQLExecutionClassifier
)
from dataflow.prompts.text2sql import (
SQLConsistencyFilterPrompt,
Text2SQLCotGeneratorPrompt,
Text2SQLQuestionGeneratorPrompt,
SQLVariationGeneratorPrompt,
Text2SQLPromptGeneratorPrompt
)
from dataflow.utils.storage import FileStorage
from dataflow.serving import LocalModelLLMServing_vllm, LocalModelLLMServing_sglang
from dataflow.utils.text2sql.database_manager import DatabaseManager
def download_and_extract_database(logger):
dataset_repo_id = "Open-Dataflow/dataflow-Text2SQL-database-example"
local_dir = "./hf_cache"
extract_to = "./downloaded_databases"
logger.info(f"Downloading and extracting database from {dataset_repo_id}...")
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
os.makedirs(local_dir, exist_ok=True)
os.makedirs(extract_to, exist_ok=True)
downloaded_path = snapshot_download(
repo_id=dataset_repo_id,
repo_type="dataset",
local_dir=local_dir,
resume_download=True
)
logger.info(f"Files downloaded to: {downloaded_path}")
zip_path = os.path.join(downloaded_path, "databases.zip")
if os.path.exists(zip_path):
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
zip_ref.extractall(extract_to)
logger.info(f"Database files extracted to {extract_to}")
return extract_to
else:
raise FileNotFoundError(f"Database zip file not found at {zip_path}")
class Text2SQLRefine_GPUPipeline():
def __init__(self, db_root_path=""):
self.logger = get_logger()
self.db_root_path = db_root_path
if not db_root_path:
try:
self.db_root_path = download_and_extract_database(self.logger)
self.logger.info(f"Using automatically downloaded database at: {self.db_root_path}")
except Exception as e:
self.logger.error(f"Failed to auto-download database: {e}")
raise
else:
self.logger.info(f"Using manually specified database path: {self.db_root_path}")
if not os.path.exists(self.db_root_path):
raise FileNotFoundError(f"Database path does not exist: {self.db_root_path}")
self.storage = FileStorage(
first_entry_file_name="../example_data/Text2SQLPipeline/pipeline_refine.jsonl",
cache_path="./cache_local",
file_name_prefix="dataflow_cache_step",
cache_type="jsonl"
)
self.llm_serving = LocalModelLLMServing_vllm(
hf_model_name_or_path="Qwen/Qwen2.5-7B-Instruct", # set to your own model path
vllm_tensor_parallel_size=1,
vllm_max_tokens=8192,
)
# It is recommended to use a stronger LLM for generating the Chain-of-Thought (CoT) reasoning.
cot_generation_llm_serving = LocalModelLLMServing_vllm(
hf_model_name_or_path="Qwen/Qwen2.5-7B-Instruct", # set to your own model path
vllm_tensor_parallel_size=1,
vllm_max_tokens=8192,
)
embedding_serving = LocalModelLLMServing_vllm(
hf_model_name_or_path="Alibaba-NLP/gte-Qwen2-7B-instruct",
vllm_max_tokens=8192
)
# SQLite and MySQL are currently supported.
# db_type can be "sqlite" or "mysql" and must match your database type.
# If sqlite is selected, root_path must be provided; the path must exist and contain the database files.
# If mysql is selected, host, user, and password must be provided; the credentials must be correct and grant access.
# MySQL example:
# database_manager = DatabaseManager(
# db_type="mysql",
# config={
# "host": "localhost",
# "user": "root",
# "password": "your_password",
# "database": "your_database_name"
# }
# )
# SQLite example:
database_manager = DatabaseManager(
db_type="sqlite",
config={
"root_path": self.db_root_path
},
logger=None,
sql_execution_timeout = 2,
max_connections_per_db=100,
max_workers=100
)
self.sql_execution_filter_step1 = SQLExecutionFilter(
database_manager=database_manager
)
self.sql_consistency_filter_step2 = SQLConsistencyFilter(
llm_serving=self.llm_serving,
database_manager=database_manager,
prompt_template=SQLConsistencyFilterPrompt()
)
self.sql_variation_generator_step3 = SQLVariationGenerator(
llm_serving=self.llm_serving,
database_manager=database_manager,
num_variations=5,
prompt_template=SQLVariationGeneratorPrompt()
)
self.sql_execution_filter_step4 = SQLExecutionFilter(
database_manager=database_manager
)
self.text2sql_question_generator_step5 = Text2SQLQuestionGenerator(
llm_serving=self.llm_serving,
embedding_serving=embedding_serving,
database_manager=database_manager,
question_candidates_num=5,
prompt_template=Text2SQLQuestionGeneratorPrompt()
)
self.text2sql_prompt_generator_step6 = Text2SQLPromptGenerator(
database_manager=database_manager,
prompt_template=Text2SQLPromptGeneratorPrompt()
)
self.sql_cot_generator_step7 = Text2SQLCoTGenerator(
llm_serving=cot_generation_llm_serving,
database_manager=database_manager,
max_retries=3,
enable_retry=True,
prompt_template=Text2SQLCotGeneratorPrompt()
)
self.sql_component_classifier_step8 = SQLComponentClassifier(
difficulty_thresholds = [2, 4, 6],
difficulty_labels = ['easy', 'medium', 'hard', 'extra']
)
self.sql_execution_classifier_step9 = SQLExecutionClassifier(
llm_serving=self.llm_serving,
database_manager=database_manager,
num_generations = 10,
difficulty_thresholds = [2, 5, 9],
difficulty_labels = ['extra', 'hard', 'medium', 'easy']
)
def forward(self):
sql_key = "SQL"
db_id_key = "db_id"
question_key = "question"
self.sql_execution_filter_step1.run(
storage=self.storage.step(),
input_sql_key=sql_key,
input_db_id_key=db_id_key
)
self.sql_consistency_filter_step2.run(
storage=self.storage.step(),
input_sql_key=sql_key,
input_db_id_key=db_id_key,
input_question_key=question_key
)
self.sql_variation_generator_step3.run(
storage=self.storage.step(),
input_sql_key=sql_key,
input_db_id_key=db_id_key
)
self.sql_execution_filter_step4.run(
storage=self.storage.step(),
input_sql_key=sql_key,
input_db_id_key=db_id_key
)
self.text2sql_question_generator_step5.run(
storage=self.storage.step(),
input_sql_key=sql_key,
input_db_id_key=db_id_key,
output_question_key=question_key
)
self.text2sql_prompt_generator_step6.run(
storage=self.storage.step(),
input_question_key=question_key,
input_db_id_key=db_id_key,
output_prompt_key="prompt"
)
self.sql_cot_generator_step7.run(
storage=self.storage.step(),
input_sql_key=sql_key,
input_question_key=question_key,
input_db_id_key=db_id_key,
output_cot_key="cot_reasoning"
)
self.sql_component_classifier_step8.run(
storage=self.storage.step(),
input_sql_key=sql_key,
output_difficulty_key="sql_component_difficulty"
)
self.sql_execution_classifier_step9.run(
storage=self.storage.step(),
input_sql_key=sql_key,
input_db_id_key=db_id_key,
input_prompt_key="prompt",
output_difficulty_key="sql_execution_difficulty"
)
if __name__ == "__main__":
# If you have your own database files, set db_root_path to their location.
# If not, leave db_root_path as an empty string ("") and the example database files will be downloaded automatically.
db_root_path = ""
model = Text2SQLRefine_GPUPipeline(db_root_path=db_root_path)
model.forward()
from dataflow.operators.general_text import (
MinHashDeduplicateFilter,
LanguageFilter,
WordNumberFilter,
BlocklistFilter,
ColonEndFilter,
SentenceNumberFilter,
LineEndWithEllipsisFilter,
ContentNullFilter,
MeanWordLengthFilter,
SymbolWordRatioFilter,
HtmlEntityFilter,
NoPuncFilter,
SpecialCharacterFilter,
WatermarkFilter,
CurlyBracketFilter,
CapitalWordsFilter,
LoremIpsumFilter,
UniqueWordsFilter,
CharNumberFilter,
LineStartWithBulletpointFilter,
LineWithJavascriptFilter,
RemoveExtraSpacesRefiner,
RemoveEmojiRefiner,
HtmlUrlRemoverRefiner,
)
from dataflow.operators.text_pt import (
PairQualFilter
)
from dataflow.utils.storage import FileStorage
class PTTextFilter_GPUPipeline():
def __init__(self):
self.storage = FileStorage(
first_entry_file_name="../example_data/GeneralTextPipeline/pt_input.jsonl",
cache_path="./cache",
file_name_prefix="dataflow_cache_step",
cache_type="jsonl",
)
self.model_cache_dir = './dataflow_cache'
self.language_filter = LanguageFilter(allowed_languages = '__label__eng_Latn', model_cache_dir = self.model_cache_dir)
self.remove_extra_spaces_refiner = RemoveExtraSpacesRefiner()
self.remove_emoji_refiner = RemoveEmojiRefiner()
self.html_remove_refiner = HtmlUrlRemoverRefiner()
self.minhash_deduplicator = MinHashDeduplicateFilter(num_perm=128, threshold=0.9, use_n_gram=True, ngram=5)
self.blocklist_filter = BlocklistFilter()
self.word_number_filter = WordNumberFilter(min_words=20, max_words=100000)
self.colon_end_filter = ColonEndFilter()
self.sentence_number_filter = SentenceNumberFilter(min_sentences=3, max_sentences=7500)
self.line_end_with_ellipsis_filter = LineEndWithEllipsisFilter(threshold=0.3)
self.content_null_filter = ContentNullFilter()
self.mean_word_length_filter = MeanWordLengthFilter(min_length=3, max_length=10)
self.symbol_word_ratio_filter = SymbolWordRatioFilter(threshold=0.4)
self.html_entity_filter = HtmlEntityFilter()
self.no_punc_filter = NoPuncFilter(threshold=112)
self.special_character_filter = SpecialCharacterFilter()
self.watermark_filter = WatermarkFilter(watermarks=['Copyright', 'Watermark', 'Confidential'])
self.curly_bracket_filter = CurlyBracketFilter(threshold=0.025)
self.capital_words_filter = CapitalWordsFilter(threshold=0.2, use_tokenizer=False)
self.lorem_ipsum_filter = LoremIpsumFilter(threshold=3e-8)
self.unique_words_filter = UniqueWordsFilter(threshold=0.1)
self.char_number_filter = CharNumberFilter(threshold=100)
self.line_start_with_bulletpoint_filter = LineStartWithBulletpointFilter(threshold=0.9)
self.line_with_javascript_filter = LineWithJavascriptFilter(threshold=3)
self.quality_filter = PairQualFilter(min_score=2, max_score=10000, lang='en')
def forward(self):
self.remove_emoji_refiner.run(
storage=self.storage.step(),
input_key="raw_content"
)
self.html_remove_refiner.run(
storage=self.storage.step(),
input_key="raw_content"
)
self.remove_extra_spaces_refiner.run(
storage=self.storage.step(),
input_key="raw_content"
)
self.blocklist_filter.run(
storage = self.storage.step(),
input_key='raw_content',
)
self.word_number_filter.run(
storage = self.storage.step(),
input_key='raw_content',
)
self.colon_end_filter.run(
storage = self.storage.step(),
input_key = 'raw_content'
)
self.sentence_number_filter.run(
storage = self.storage.step(),
input_key = 'raw_content'
)
self.line_end_with_ellipsis_filter.run(
storage = self.storage.step(),
input_key = 'raw_content'
)
self.content_null_filter.run(
storage = self.storage.step(),
input_key='raw_content',
)
self.mean_word_length_filter.run(
storage = self.storage.step(),
input_key='raw_content',
)
self.symbol_word_ratio_filter.run(
storage = self.storage.step(),
input_key='raw_content',
)
self.html_entity_filter.run(
storage = self.storage.step(),
input_key='raw_content',
)
self.no_punc_filter.run(
storage = self.storage.step(),
input_key='raw_content',
)
self.special_character_filter.run(
storage = self.storage.step(),
input_key='raw_content',
)
self.watermark_filter.run(
storage = self.storage.step(),
input_key='raw_content',
)
self.curly_bracket_filter.run(
storage = self.storage.step(),
input_key='raw_content',
)
self.capital_words_filter.run(
storage = self.storage.step(),
input_key='raw_content',
)
self.lorem_ipsum_filter.run(
storage = self.storage.step(),
input_key='raw_content',
)
self.unique_words_filter.run(
storage = self.storage.step(),
input_key='raw_content',
)
self.char_number_filter.run(
storage = self.storage.step(),
input_key='raw_content',
)
self.line_start_with_bulletpoint_filter.run(
storage = self.storage.step(),
input_key='raw_content',
)
self.line_with_javascript_filter.run(
storage = self.storage.step(),
input_key='raw_content',
)
self.quality_filter.run(
storage = self.storage.step(),
input_key='raw_content',
)
self.language_filter.run(
storage = self.storage.step(),
input_key = "raw_content"
)
self.minhash_deduplicator.run(
storage = self.storage.step(),
input_key='raw_content',
)
if __name__ == "__main__":
# This is the entry point for the pipeline
model = PTTextFilter_GPUPipeline()
model.forward()
from dataflow.operators.general_text import (
MinHashDeduplicateFilter,
LanguageFilter,
WordNumberFilter,
BlocklistFilter,
ColonEndFilter,
SentenceNumberFilter,
LineEndWithEllipsisFilter,
ContentNullFilter,
MeanWordLengthFilter,
SymbolWordRatioFilter,
HtmlEntityFilter,
IDCardFilter,
NoPuncFilter,
SpecialCharacterFilter,
WatermarkFilter,
CurlyBracketFilter,
CapitalWordsFilter,
LoremIpsumFilter,
UniqueWordsFilter,
CharNumberFilter,
LineStartWithBulletpointFilter,
LineWithJavascriptFilter,
RemoveExtraSpacesRefiner,
RemoveEmojiRefiner,
HtmlUrlRemoverRefiner,
)
from dataflow.operators.text_pt import (
PairQualFilter,
QuratingFilter
)
from dataflow.operators.text_pt import Phi4QAGenerator
from dataflow.serving import LocalModelLLMServing_vllm, LocalModelLLMServing_sglang
from dataflow.utils.storage import FileStorage
class PTTextSynthetic_GPUPipeline():
def __init__(self):
self.storage = FileStorage(
first_entry_file_name="../example_data/GeneralTextPipeline/pt_input.jsonl",
cache_path="./cache",
file_name_prefix="dataflow_cache_step",
cache_type="jsonl",
)
self.model_cache_dir = './dataflow_cache'
# use local model as LLM serving
self.llm_serving = LocalModelLLMServing_vllm(
hf_model_name_or_path='Qwen/Qwen2.5-7B-Instruct',
vllm_tensor_parallel_size=1,
vllm_max_tokens=8192,
)
# use SGLang as LLM serving
# self.llm_serving = LocalModelLLMServing_sglang(
# hf_model_name_or_path="Qwen/Qwen2.5-7B-Instruct",
# sgl_dp_size=1, # data parallel size
# sgl_tp_size=1, # tensor parallel size
# sgl_max_new_tokens=1024,
# )
self.language_filter = LanguageFilter(allowed_languages = '__label__eng_Latn', model_cache_dir = self.model_cache_dir)
self.remove_extra_spaces_refiner = RemoveExtraSpacesRefiner()
self.remove_emoji_refiner = RemoveEmojiRefiner()
self.html_remove_refiner = HtmlUrlRemoverRefiner()
self.minhash_deduplicator = MinHashDeduplicateFilter(num_perm=128, threshold=0.9, use_n_gram=True, ngram=5)
self.blocklist_filter = BlocklistFilter()
self.word_number_filter = WordNumberFilter(min_words=20, max_words=100000)
self.colon_end_filter = ColonEndFilter()
self.sentence_number_filter = SentenceNumberFilter(min_sentences=3, max_sentences=7500)
self.line_end_with_ellipsis_filter = LineEndWithEllipsisFilter(threshold=0.3)
self.content_null_filter = ContentNullFilter()
self.mean_word_length_filter = MeanWordLengthFilter(min_length=3, max_length=10)
self.symbol_word_ratio_filter = SymbolWordRatioFilter(threshold=0.4)
self.html_entity_filter = HtmlEntityFilter()
self.id_card_filter = IDCardFilter(threshold=3)
self.no_punc_filter = NoPuncFilter(threshold=112)
self.special_character_filter = SpecialCharacterFilter()
self.watermark_filter = WatermarkFilter(watermarks=['Copyright', 'Watermark', 'Confidential'])
self.curly_bracket_filter = CurlyBracketFilter(threshold=0.025)
self.capital_words_filter = CapitalWordsFilter(threshold=0.2, use_tokenizer=False)
self.lorem_ipsum_filter = LoremIpsumFilter(threshold=3e-8)
self.unique_words_filter = UniqueWordsFilter(threshold=0.1)
self.char_number_filter = CharNumberFilter(threshold=100)
self.line_start_with_bulletpoint_filter = LineStartWithBulletpointFilter(threshold=0.9)
self.line_with_javascript_filter = LineWithJavascriptFilter(threshold=3)
self.quality_filter = PairQualFilter(min_score=-2, max_score=10000, lang='en')
self.pt_generator = Phi4QAGenerator(
llm_serving=self.llm_serving
)
self.qurating_filter = QuratingFilter(
min_scores = {'writing_style': 0, 'required_expertise': 0, 'facts_and_trivia': 0, 'educational_value': 0},
max_scores = {'writing_style': 9, 'required_expertise': 9, 'facts_and_trivia': 9, 'educational_value': 9}
)
def forward(self):
# Initial filters
self.language_filter.run(
storage = self.storage.step(),
input_key = "raw_content"
)
# refiners
self.remove_extra_spaces_refiner.run(
storage=self.storage.step(),
input_key="raw_content"
)
self.remove_emoji_refiner.run(
storage=self.storage.step(),
input_key="raw_content"
)
self.html_remove_refiner.run(
storage=self.storage.step(),
input_key="raw_content"
)
self.minhash_deduplicator.run(
storage = self.storage.step(),
input_key='raw_content',
output_key='minhash_deduplicated_label',
)
self.blocklist_filter.run(
storage = self.storage.step(),
input_key='raw_content',
)
self.word_number_filter.run(
storage = self.storage.step(),
input_key='raw_content',
)
self.colon_end_filter.run(
storage = self.storage.step(),
input_key = 'raw_content'
)
self.sentence_number_filter.run(
storage = self.storage.step(),
input_key = 'raw_content'
)
self.line_end_with_ellipsis_filter.run(
storage = self.storage.step(),
input_key = 'raw_content'
)
# Add the additional filters here
self.content_null_filter.run(
storage = self.storage.step(),
input_key='raw_content',
)
self.mean_word_length_filter.run(
storage = self.storage.step(),
input_key='raw_content',
)
self.symbol_word_ratio_filter.run(
storage = self.storage.step(),
input_key='raw_content',
)
self.html_entity_filter.run(
storage = self.storage.step(),
input_key='raw_content',
)
self.id_card_filter.run(
storage = self.storage.step(),
input_key='raw_content',
)
self.no_punc_filter.run(
storage = self.storage.step(),
input_key='raw_content',
)
self.special_character_filter.run(
storage = self.storage.step(),
input_key='raw_content',
)
self.watermark_filter.run(
storage = self.storage.step(),
input_key='raw_content',
)
self.curly_bracket_filter.run(
storage = self.storage.step(),
input_key='raw_content',
)
self.capital_words_filter.run(
storage = self.storage.step(),
input_key='raw_content',
)
self.lorem_ipsum_filter.run(
storage = self.storage.step(),
input_key='raw_content',
)
self.unique_words_filter.run(
storage = self.storage.step(),
input_key='raw_content',
)
self.char_number_filter.run(
storage = self.storage.step(),
input_key='raw_content',
)
self.line_start_with_bulletpoint_filter.run(
storage = self.storage.step(),
input_key='raw_content',
)
self.line_with_javascript_filter.run(
storage = self.storage.step(),
input_key='raw_content',
)
self.quality_filter.run(
storage = self.storage.step(),
input_key='raw_content',
)
self.pt_generator.run(
storage=self.storage.step(),
input_key='raw_content',
output_key='generated_content'
)
self.qurating_filter.run(
storage=self.storage.step(),
input_key='generated_content'
)
if __name__ == "__main__":
model = PTTextSynthetic_GPUPipeline()
model.forward()