Commit 97e8278b authored by zzg_666

Adapt the backend to vLLM

import shutil
from colorama import init, Fore, Style
from pathlib import Path
def copy_file(source: Path, destination: Path):
"""
Copy a single file to the destination path, asking before overwriting an existing file.
"""
if not source.is_file():
print(f"Error: {source} is not a file.")
return
if destination.exists():
print(f"Warning: {destination.name} already exists in {destination.parent}.")
user_input = input(f"Do you want to overwrite {destination.name}? (y/n): ").strip().lower()
if user_input != 'y':
print(f"Skipping {destination.name}.")
return
shutil.copy(source, destination)
print(f"Copied {source.name} to {destination}.")
def copy_files_without_recursion(source: Path, destination):
"""
Copy the files directly under a directory (non-recursive), asking before overwriting existing files.
"""
yes_to_all, none_to_all = False, False
for template_file in source.iterdir():
if template_file.is_file():
if template_file.name == "__init__.py":
continue # skip __init__.py files
destination_file = Path(destination) / template_file.name
if destination_file.exists():
if none_to_all:
print(f' Skipping {template_file.name}.\n')
continue
if not yes_to_all:
# Warn and ask whether to overwrite
print(f' {Fore.YELLOW}Warning: {template_file.name} already exists in {destination}.{Style.RESET_ALL}')
user_input = input(f' Do you want to overwrite {template_file.name}? (y/n/all/none): ').strip().lower()
if user_input == 'all':
yes_to_all = True
elif user_input == 'none':
none_to_all = True
print(f' Skipping {template_file.name}.\n')
continue
elif user_input != 'y':
print(f' Skipping {template_file.name}.\n')
continue
shutil.copy(template_file, destination_file)
print(f' Copied {template_file.name} to {destination}.\n')
def copy_files_recursively(source_path: Path, destination_path: Path):
"""
Recursively copy all contents from source_path to destination_path.
Prompts user if a file already exists in destination.
"""
if not source_path.exists():
print(f"{Fore.RED}Error: Source path does not exist.{Style.RESET_ALL}")
return
if not source_path.is_dir():
print(f"{Fore.RED}Error: Source path is not a directory.{Style.RESET_ALL}")
return
yes_to_all = False
none_to_all = False
for item in source_path.rglob('*'):
relative_path = item.relative_to(source_path)
dest_item = destination_path / relative_path
if item.is_dir():
dest_item.mkdir(parents=True, exist_ok=True)
else:
if dest_item.exists():
if yes_to_all:
pass # proceed with overwrite
elif none_to_all:
print(f' Skipping {dest_item.name}.\n')
continue
else:
print(f' {Fore.YELLOW}Warning: {dest_item.name} already exists in {destination_path}.{Style.RESET_ALL}')
user_input = input(f' Do you want to overwrite {dest_item.name}? (y/n/all/none): ').strip().lower()
if user_input == 'all':
yes_to_all = True
elif user_input == 'none':
none_to_all = True
print(f' Skipping {dest_item.name}.\n')
continue
elif user_input != 'y':
print(f' Skipping {dest_item.name}.\n')
continue
shutil.copy2(item, dest_item)
# print a clear multi-line summary of the copy
print(f'{Fore.GREEN}[Copied]\nFrom: {item}\nTo: {dest_item}{Style.RESET_ALL}\n')
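# Illustrative usage sketch; the template/destination paths below are hypothetical examples.
if __name__ == "__main__":
    init(autoreset=True)  # colorama init so Fore/Style colors reset after each print
    copy_files_recursively(Path("./templates"), Path("./my_project"))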
# eval_api.py - API evaluation configuration file
"""DataFlow API Evaluation Configuration - Enhanced Version"""
import os
from dataflow.operators.core_text import BenchDatasetEvaluatorQuestion
from dataflow.serving import APILLMServing_request
from dataflow.utils.storage import FileStorage
# =============================================================================
# Fair Evaluation Prompt Template
# =============================================================================
class FairAnswerJudgePrompt:
"""Fair answer evaluation prompt template with English prompts"""
# Default judge-model prompt. This prompt is used by the judge model; do not confuse it with the prompt for the models being evaluated.
def build_prompt(self, question, answer, reference_answer):
prompt = f"""You are an expert evaluator assessing answer quality for academic questions.
**Question:**
{question}
**Answer to Evaluate:**
{answer}
**Evaluation Instructions:**
Judge this answer based on:
1. **Factual Accuracy**: Is the information correct?
2. **Completeness**: Does it address the key aspects of the question?
3. **Relevance**: Is it directly related to what was asked?
4. **Academic Quality**: Is the reasoning sound and appropriate?
**Important Guidelines:**
- Focus on content correctness, not writing style
- A good answer may be longer, shorter, or differently structured
- Accept different valid approaches or explanations
- Judge based on whether the answer demonstrates correct understanding
- Consider partial credit for answers that are mostly correct
**Reference Answer (for context only):** {reference_answer}
**Output Format:**
Return your judgment in JSON format:
{{"judgement_result": true}} if the answer is factually correct and adequately addresses the question
{{"judgement_result": false}} if the answer contains significant errors or fails to address the question
**Your Judgment:**"""
return prompt
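# A minimal sketch of how a caller might parse the JSON verdict this prompt requests.
# parse_judgement is a hypothetical helper for illustration only; it is not part of the DataFlow evaluator.
def parse_judgement(raw_response: str) -> bool:
    """Extract {"judgement_result": ...} from a judge reply; unparsable replies count as False."""
    import json
    import re
    match = re.search(r'\{[^{}]*"judgement_result"[^{}]*\}', raw_response)
    if not match:
        return False
    try:
        return bool(json.loads(match.group(0)).get("judgement_result", False))
    except json.JSONDecodeError:
        return False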
# =============================================================================
# Configuration Parameters
# =============================================================================
# Judge Model Configuration (API model as judge)
JUDGE_MODEL_CONFIG = {
"model_name": "gpt-4o-mini",
"api_url": "", # 请求URL 必填 / request (required)
"api_key_env": "DF_API_KEY", # api_key 必填 / api_key (required)
"max_workers": 3,
"max_retries": 5,
}
# Target Models Configuration (list format - required, each element is a dict)
TARGET_MODELS = [
# {
# "name": "qwen_3b", # 模型名称(可选,默认使用路径最后一部分) / Model name (optional, uses the last part of the path by default)
# "path": "./Qwen2.5-3B-Instruct", # 模型路径(必需) / Model path (required)
# # ===== 答案生成的模型加载参数(可选)=====
# "tensor_parallel_size": 1, # GPU并行数量 / Number of GPU parallels
# "max_tokens": 1024, # 最大生成token数 / Maximum number of generated tokens
# "gpu_memory_utilization": 0.8, # GPU显存利用率 / GPU memory utilization
# },
{
"name": "qwen_7b",
"path": "./Qwen2.5-7B-Instruct",
# larger models can use different parameters
"tensor_parallel_size": 2,
"max_tokens": 2048,
"gpu_memory_utilization": 0.9,
# You can customize the prompt for each model. If not specified, it defaults to the template in the build_prompt function.
# Default prompt for evaluated models
# IMPORTANT: This is the prompt for the models being evaluated, NOT for the judge model!
"answer_prompt": """please answer the following question:""" # do not include {question} here
},
# Add more models...
# {
# "name": "llama_8b",
# "path": "meta-llama/Llama-3-8B-Instruct",
# "tensor_parallel_size": 2
# }
]
# Data Configuration
DATA_CONFIG = {
"input_file": "./.cache/data/qa.json", # 输入数据文件
"output_dir": "./eval_results", # 输出目录
"question_key": "input", # 原始数据中的问题字段
"reference_answer_key": "output" # 原始数据中的参考答案字段
}
# Evaluator Run Configuration (parameters passed to BenchDatasetEvaluator.run)
EVALUATOR_RUN_CONFIG = {
"input_test_answer_key": "model_generated_answer", # 模型生成的答案字段名
"input_gt_answer_key": "output", # 标准答案字段名(对应原始数据)
"input_question_key": "input" # 问题字段名(对应原始数据)
}
# Evaluation Configuration
EVAL_CONFIG = {
"compare_method": "semantic", # "semantic" 语义匹配 或 "match" 字段完全匹配
}
# =============================================================================
# Component Creation Functions
# =============================================================================
def create_judge_serving():
"""创建评估器LLM服务(API模式)"""
api_key_env = JUDGE_MODEL_CONFIG["api_key_env"]
if api_key_env not in os.environ:
raise ValueError(f"Environment variable {api_key_env} is not set. "
f"Please set it with your API key.")
api_key = os.environ[api_key_env]
if not api_key.strip():
raise ValueError(f"Environment variable {api_key_env} is empty. "
f"Please provide a valid API key.")
return APILLMServing_request(
api_url=JUDGE_MODEL_CONFIG["api_url"],
key_name_of_api_key=api_key_env,
model_name=JUDGE_MODEL_CONFIG["model_name"],
max_workers=JUDGE_MODEL_CONFIG.get("max_workers", 10),
max_retries=JUDGE_MODEL_CONFIG.get("max_retries", 5)
)
def create_evaluator(judge_serving, eval_result_path):
"""创建评估算子"""
return BenchDatasetEvaluatorQuestion(
compare_method=EVAL_CONFIG["compare_method"],
llm_serving=judge_serving,
prompt_template=FairAnswerJudgePrompt(),
eval_result_path=eval_result_path
)
def create_storage(data_file, cache_path):
"""创建存储算子"""
return FileStorage(
first_entry_file_name=data_file,
cache_path=cache_path,
file_name_prefix="eval_result",
cache_type="json"
)
# =============================================================================
# Main Configuration Function
# =============================================================================
def get_evaluator_config():
# Return the complete configuration
return {
"JUDGE_MODEL_CONFIG": JUDGE_MODEL_CONFIG, # 评估模型设置映射
"TARGET_MODELS": TARGET_MODELS, # 被评估模型设置映射
"DATA_CONFIG": DATA_CONFIG, # 数据设置映射
"EVAL_CONFIG": EVAL_CONFIG, # 评估模式设置映射
"EVALUATOR_RUN_CONFIG": EVALUATOR_RUN_CONFIG, # 评估数据集字段映射
"create_judge_serving": create_judge_serving,
"create_evaluator": create_evaluator,
"create_storage": create_storage
}
# =============================================================================
# Direct Execution Support
# Run the evaluation directly
# =============================================================================
if __name__ == "__main__":
# Simple evaluation when run directly
print("Starting API evaluation...")
from dataflow.cli_funcs.cli_eval import run_evaluation
try:
config = get_evaluator_config()
success = run_evaluation(config)
if success:
print("API evaluation completed successfully")
else:
print("API evaluation failed")
except Exception as e:
print(f"Evaluation error: {e}")
import traceback
traceback.print_exc()
# eval_local.py - Local evaluation configuration file
"""DataFlow Local Evaluation Configuration - Enhanced Version"""
from pathlib import Path
from dataflow.serving import LocalModelLLMServing_vllm
from dataflow.utils.storage import FileStorage
from dataflow.operators.core_text import BenchDatasetEvaluatorQuestion
# =============================================================================
# Fair Evaluation Prompt Template
# =============================================================================
class FairAnswerJudgePrompt:
"""Fair answer evaluation prompt template with English prompts"""
def build_prompt(self, question, answer, reference_answer):
prompt = f"""You are an expert evaluator assessing answer quality for academic questions.
**Question:**
{question}
**Answer to Evaluate:**
{answer}
**Evaluation Instructions:**
Judge this answer based on:
1. **Factual Accuracy**: Is the information correct?
2. **Completeness**: Does it address the key aspects of the question?
3. **Relevance**: Is it directly related to what was asked?
4. **Academic Quality**: Is the reasoning sound and appropriate?
**Important Guidelines:**
- Focus on content correctness, not writing style
- A good answer may be longer, shorter, or differently structured
- Accept different valid approaches or explanations
- Judge based on whether the answer demonstrates correct understanding
- Consider partial credit for answers that are mostly correct
**Reference Answer (for context only):** {reference_answer}
**Output Format:**
Return your judgment in JSON format:
{{"judgement_result": true}} if the answer is factually correct and adequately addresses the question
{{"judgement_result": false}} if the answer contains significant errors or fails to address the question
**Your Judgment:**"""
return prompt
# =============================================================================
# Configuration Parameters
# =============================================================================
# Judge Model Configuration (local strong model as judge)
JUDGE_MODEL_CONFIG = {
"model_path": "./Qwen2.5-7B-Instruct", # 用更强的模型做裁判
"tensor_parallel_size": 1,
"max_tokens": 512,
"gpu_memory_utilization": 0.8,
}
# Target Models Configuration (list format - required, each element is a dict)
TARGET_MODELS = [
# {
# "name": "qwen_3b", # 模型名称(可选,默认使用路径最后一部分)
# "path": "./Qwen2.5-3B-Instruct", # 模型路径(必需)
# # ===== 答案生成的模型加载参数(可选)=====
# "tensor_parallel_size": 1, # GPU并行数量
# "max_tokens": 1024, # 最大生成token数
# "gpu_memory_utilization": 0.8, # GPU显存利用率
# },
{
"name": "qwen_7b",
"path": "./Qwen2.5-7B-Instruct",
# larger models can use different parameters
"tensor_parallel_size": 2,
"max_tokens": 2048,
"gpu_memory_utilization": 0.9,
# You can customize the prompt for each model (optional; defaults to the template in build_prompt)
"answer_prompt": """please answer the following question:"""
},
# Add more models...
# {
# "name": "llama_8b",
# "path": "meta-llama/Llama-3-8B-Instruct",
# "tensor_parallel_size": 2
# }
]
# Data Configuration
DATA_CONFIG = {
"input_file": "/data1/fyl/workspace/.cache/data/qa.json", # 输入数据文件
"output_dir": "./eval_results", # 输出目录
"question_key": "input", # 原始数据中的问题字段
"reference_answer_key": "output" # 原始数据中的参考答案字段
}
# Evaluator Run Configuration (parameters passed to BenchDatasetEvaluator.run)
EVALUATOR_RUN_CONFIG = {
"input_test_answer_key": "model_generated_answer", # 模型生成的答案字段名
"input_gt_answer_key": "output", # 标准答案字段名(对应原始数据)
"input_question_key": "input" # 问题字段名(对应原始数据)
}
# Evaluation Configuration
EVAL_CONFIG = {
"compare_method": "semantic", # "semantic" 语义匹配 或 "match" 字段完全匹配
}
# =============================================================================
# Component Creation Functions - DataFlow Style
# =============================================================================
def create_judge_serving():
"""创建本地评估器LLM服务"""
model_path = JUDGE_MODEL_CONFIG["model_path"]
# Enhanced model path validation
if not model_path.startswith(("Qwen", "meta-llama", "microsoft", "google", "huggingface.co")):
model_path_obj = Path(model_path)
if not model_path_obj.exists():
raise FileNotFoundError(f"Local model path does not exist: {model_path}")
# Check for required model files
required_files = ["config.json"]
missing_files = [f for f in required_files if not (model_path_obj / f).exists()]
if missing_files:
raise ValueError(f"Missing required model files in {model_path}: {missing_files}")
# Enhanced VLLM configuration
vllm_config = {
"hf_model_name_or_path": model_path,
"vllm_tensor_parallel_size": JUDGE_MODEL_CONFIG.get("tensor_parallel_size", 1),
"vllm_max_tokens": JUDGE_MODEL_CONFIG.get("max_tokens", 512),
"vllm_gpu_memory_utilization": JUDGE_MODEL_CONFIG.get("gpu_memory_utilization", 0.8)
}
# Add optional VLLM parameters if they exist
optional_params = ["dtype", "trust_remote_code", "enforce_eager", "disable_log_stats"]
for param in optional_params:
if param in JUDGE_MODEL_CONFIG:
vllm_config[f"vllm_{param}"] = JUDGE_MODEL_CONFIG[param]
return LocalModelLLMServing_vllm(**vllm_config)
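# For example, to forward an optional vLLM flag, one could add it to JUDGE_MODEL_CONFIG before calling
# create_judge_serving (illustrative): JUDGE_MODEL_CONFIG["dtype"] = "bfloat16" is passed through above as vllm_dtype.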
def create_evaluator(judge_serving, eval_result_path):
"""创建评估算子"""
return BenchDatasetEvaluatorQuestion(
compare_method=EVAL_CONFIG["compare_method"],
llm_serving=judge_serving,
prompt_template=FairAnswerJudgePrompt(),
eval_result_path=eval_result_path
)
def create_storage(data_file, cache_path):
"""创建存储算子"""
return FileStorage(
first_entry_file_name=data_file,
cache_path=cache_path,
file_name_prefix="eval_result",
cache_type="json"
)
# =============================================================================
# Main Configuration Function
# =============================================================================
def get_evaluator_config():
"""返回完整配置"""
return {
"JUDGE_MODEL_CONFIG": JUDGE_MODEL_CONFIG,
"TARGET_MODELS": TARGET_MODELS,
"DATA_CONFIG": DATA_CONFIG,
"EVALUATOR_RUN_CONFIG": EVALUATOR_RUN_CONFIG,
"EVAL_CONFIG": EVAL_CONFIG,
"create_judge_serving": create_judge_serving,
"create_evaluator": create_evaluator,
"create_storage": create_storage
}
# =============================================================================
# Direct Execution Support
# =============================================================================
if __name__ == "__main__":
# Simple evaluation when run directly
print("Starting local evaluation...")
from dataflow.cli_funcs.cli_eval import run_evaluation
try:
config = get_evaluator_config()
success = run_evaluation(config)
if success:
print("Local evaluation completed successfully")
else:
print("Local evaluation failed")
except Exception as e:
print(f"Evaluation error: {e}")
import traceback
traceback.print_exc()
import os
from pathlib import Path
import appdirs
class DataFlowPath:
"""
Class to manage paths for DataFlow.
"""
@staticmethod
def get_dataflow_dir():
# return path of /dataflow
return Path(__file__).parent.parent
# @staticmethod
# def get_dataflow_scripts_dir():
# return DataFlowPath.get_dataflow_dir() / "scripts"
@staticmethod
def get_dataflow_example_dir():
return DataFlowPath.get_dataflow_dir() / "example"
@staticmethod
def get_dataflow_statics_dir():
return DataFlowPath.get_dataflow_dir() / "statics"
@staticmethod
def get_dataflow_pipelines_dir():
return DataFlowPath.get_dataflow_statics_dir() / "pipelines"
@staticmethod
def get_dataflow_playground_dir():
return DataFlowPath.get_dataflow_statics_dir() / "playground"
# @staticmethod
# def get_dataset_json_dir() -> Path:
# return DataFlowPath.get_dataflow_dir() / "dataset_json"
# @staticmethod
# def get_init_base_dir() -> Path:
# return DataFlowPath.get_dataflow_dir() / "init_base"
# @staticmethod
# def get_model_zoo_runs_dir() -> Path:
# return DataFlowPath.get_dataflow_dir() / "model_zoo" / "runs"
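# Illustrative usage sketch: resolve the bundled pipeline templates directory.
# pipelines_dir = DataFlowPath.get_dataflow_pipelines_dir()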
#!/usr/bin/env python3
"""
LlamaFactory Training Script with YAML Configuration Management
Complete "check-create-read-train" workflow
"""
import yaml
import os
import sys
import subprocess
import argparse
from pathlib import Path
class LlamaFactoryTrainer:
def __init__(self, config_path=None, cache_base="./"):
# resolve the cache_base relative path against the caller's working directory
cache_path = Path(cache_base)
if not cache_path.is_absolute():
caller_cwd = Path(os.environ.get('PWD', os.getcwd()))
cache_path = caller_cwd / cache_path
self.cache_path = cache_path
# if config_path is not given, use the default path under the cache directory
if config_path is None:
config_path = str(cache_path / ".cache" / "train_config.yaml")
else:
# resolve the config_path relative path
config_path_obj = Path(config_path)
if not config_path_obj.is_absolute():
caller_cwd = Path(os.environ.get('PWD', os.getcwd()))
config_path = str(caller_cwd / config_path_obj)
self.config_path = Path(config_path)
def get_default_config(self):
"""Get default configuration - 基于最新LlamaFactory标准"""
return {
# === Basic settings ===
"stage": "sft",
"do_train": True,
# === Model settings ===
"model_name_or_path": "Qwen/Qwen2.5-7B-Instruct",
"template": "qwen",
"trust_remote_code": True,
# === Fine-tuning method settings ===
"finetuning_type": "lora",
"lora_target": "all", # 使用 "all" 而不是具体层名
"lora_rank": 16,
"lora_alpha": 32,
"lora_dropout": 0.05,
# === Dataset settings ===
"dataset": "kb_qa",
"dataset_dir": str(self.cache_path / ".cache" / "data"),
"cutoff_len": 1024,
"max_samples": None,
"overwrite_cache": True,
"preprocessing_num_workers": 4,
# === Training settings ===
"per_device_train_batch_size": 2,
"gradient_accumulation_steps": 4,
"learning_rate": 5e-5,
"num_train_epochs": 2.0,
"lr_scheduler_type": "cosine",
"warmup_ratio": 0.05,
"weight_decay": 0.01,
"max_grad_norm": 1.0,
"optim": "adamw_torch",
# === Output settings ===
"output_dir": str(self.cache_path / ".cache" / "saves" / "qwen2.5_7b_sft_model"),
"overwrite_output_dir": True,
"save_only_model": True,
"plot_loss": True,
# === Logging and checkpoints ===
"logging_steps": 10,
"save_steps": 300,
"save_total_limit": 2,
# === Evaluation settings ===
"val_size": 0.1,
"per_device_eval_batch_size": 2,
"eval_strategy": "steps",
"eval_steps": 300,
# === Hardware settings ===
"fp16": False,
"bf16": True,
"tf32": True,
"dataloader_num_workers": 2,
# === Miscellaneous ===
"seed": 42,
"ddp_timeout": 1800,
"report_to": "none",
"run_name": "qwen2.5_7b_sft_training",
}
def check_and_create_config(self):
"""
Core logic: check config file, read if exists, create if not
"""
# Ensure directory exists
self.config_path.parent.mkdir(parents=True, exist_ok=True)
if self.config_path.exists():
# Config file exists, read it
print(f"Found config file: {self.config_path}")
return self._load_existing_config()
else:
# Config file doesn't exist, create default config
print(f"Creating new config file: {self.config_path}")
return self._create_new_config()
def _load_existing_config(self):
"""Load existing config file"""
try:
with open(self.config_path, 'r', encoding='utf-8') as f:
config = yaml.safe_load(f)
print("Config file loaded successfully")
# Check if new default parameters need to be added
default_config = self.get_default_config()
updated = False
for key, value in default_config.items():
if key not in config:
config[key] = value
updated = True
print(f"Added missing parameter: {key}")
# remove the deprecated formatting parameter
if "formatting" in config:
del config["formatting"]
updated = True
print("Removed deprecated parameter: formatting")
# Save config if updated
if updated:
self._save_config(config)
print("Config file updated")
return config
except Exception as e:
print(f"Failed to read config: {e}")
print("Creating new default config")
return self._create_new_config()
def _create_new_config(self):
"""Create new default config file"""
default_config = self.get_default_config()
if self._save_config(default_config):
print("Default config file created successfully")
return default_config
else:
print("Failed to create config file")
return None
def _save_config(self, config):
"""Save config to file"""
try:
with open(self.config_path, 'w', encoding='utf-8') as f:
yaml.dump(config, f,
default_flow_style=False,
allow_unicode=True,
sort_keys=False,
indent=2,
width=80)
return True
except Exception as e:
print(f"Failed to save config: {e}")
return False
def update_config(self, updates):
"""Update config parameters"""
config = self.check_and_create_config()
if config is None:
return None
config.update(updates)
if self._save_config(config):
print(f"Config updated: {list(updates.keys())}")
return config
return None
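# Illustrative usage sketch (the override values are examples only):
# trainer = LlamaFactoryTrainer(cache_base="./")
# trainer.update_config({"learning_rate": 1e-4, "num_train_epochs": 3.0})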
def print_config_info(self, config):
"""Print config information"""
if not config:
return
print("\n" + "=" * 50)
print("Training Configuration")
print("=" * 50)
key_info = [
("Model", "model_name_or_path"),
("Dataset", "dataset"),
("Template", "template"),
("LoRA Rank", "lora_rank"),
("Learning Rate", "learning_rate"),
("Epochs", "num_train_epochs"),
("Batch Size", "per_device_train_batch_size"),
("Output Dir", "output_dir"),
]
for name, key in key_info:
print(f"{name:<12}: {config.get(key, 'N/A')}")
effective_batch = config.get("per_device_train_batch_size", 1) * config.get("gradient_accumulation_steps", 1)
print(f"Effective Batch: {effective_batch}")
print("=" * 50)
def check_environment(self):
"""Check training environment"""
print("Checking training environment...")
# use absolute paths based on the caller's working directory
caller_cwd = Path(os.environ.get('PWD', os.getcwd()))
data_dir = caller_cwd / ".cache" / "data"
# Check data files with detailed info
data_files = [
str(data_dir / "qa.json"),
str(data_dir / "dataset_info.json")
]
for file_path in data_files:
if not Path(file_path).exists():
print(f"❌ Missing data file: {file_path}")
return False
else:
file_size = Path(file_path).stat().st_size
print(f"✅ Found data file: {file_path} ({file_size} bytes)")
# inspect the contents of qa.json
try:
import json
qa_file = str(data_dir / "qa.json")
with open(qa_file, 'r', encoding='utf-8') as f:
qa_data = json.load(f)
print(f"✅ QA data loaded: {len(qa_data)} samples")
if len(qa_data) == 0:
print("❌ QA data is empty! Please check data generation pipeline.")
return False
# check the format of the first sample
if qa_data:
sample = qa_data[0]
required_keys = ["instruction", "input", "output"]
missing_keys = [key for key in required_keys if key not in sample]
if missing_keys:
print(f"❌ Missing keys in QA sample: {missing_keys}")
return False
else:
print("✅ QA data format is correct")
except Exception as e:
print(f"❌ Error checking QA data: {e}")
return False
print("✅ Data files check passed")
# Check LlamaFactory
try:
import llamafactory
print("✅ LlamaFactory is installed")
except ImportError:
print("❌ LlamaFactory not installed")
print("Please run: pip install llamafactory[torch,metrics]")
return False
return True
def start_training(self):
"""Start training"""
if not self.check_environment():
print("Environment check failed")
return False
# Load config
config = self.check_and_create_config()
if not config:
print("Failed to load config")
return False
# Show config info
self.print_config_info(config)
# Build training command
train_cmd = f"llamafactory-cli train {self.config_path}"
# Start training
print(f"\nStarting training...")
print(f"Config file: {self.config_path}")
print(f"Command: {train_cmd}")
try:
result = subprocess.run(train_cmd, shell=True, check=True)
print("Training completed!")
return True
except subprocess.CalledProcessError as e:
print(f"Training failed: {e}")
return False
except KeyboardInterrupt:
print("\nTraining interrupted by user")
return False
def main():
parser = argparse.ArgumentParser(description="LlamaFactory Training Script")
parser.add_argument("--config", default=None,
help="Config file path (default: cache_dir/.cache/train_config.yaml)")
parser.add_argument("--cache", default="./", help="Cache directory path")
parser.add_argument("--update-lr", type=float,
help="Update learning rate")
parser.add_argument("--update-epochs", type=int,
help="Update training epochs")
parser.add_argument("--update-batch-size", type=int,
help="Update batch size")
args = parser.parse_args()
# resolve the cache_base relative path
cache_path = Path(args.cache)
if not cache_path.is_absolute():
caller_cwd = Path(os.environ.get('PWD', os.getcwd()))
cache_path = caller_cwd / cache_path
# Create trainer - if no config is given, the default path under the cache directory is used
config_path = args.config
if config_path is None:
config_path = str(cache_path / ".cache" / "train_config.yaml")
trainer = LlamaFactoryTrainer(config_path, cache_base=str(cache_path))
# Update parameters if specified
updates = {}
if args.update_lr:
updates["learning_rate"] = args.update_lr
if args.update_epochs:
updates["num_train_epochs"] = args.update_epochs
if args.update_batch_size:
updates["per_device_train_batch_size"] = args.update_batch_size
if updates:
print("Updating config parameters...")
trainer.update_config(updates)
# Start training
success = trainer.start_training()
if success:
print(f"\nOperation completed!")
print(f"Model saved to: {cache_path / '.cache' / 'saves' / 'qwen2.5_7b_sft_model'}")
else:
sys.exit(1)
if __name__ == "__main__":
main()
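# Example invocation (the script filename here is a hypothetical placeholder; the flags are defined in main() above):
#   python train.py --cache ./ --update-lr 1e-4 --update-epochs 3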
import os
import json
import argparse
from pathlib import Path
from typing import List, Union
class PDFDetector:
"""PDF file detector for scanning directories and generating JSONL config files"""
def __init__(self, output_file: str = "./.cache/gpu/pdf_list.jsonl"):
# Handle output path - based on caller's working directory
output_path = Path(output_file)
if not output_path.is_absolute():
caller_cwd = Path(os.environ.get('PWD', os.getcwd()))
output_file = str(caller_cwd / output_file)
self.output_file = output_file
self.pdf_files = []
def scan_directory(self, directory: Union[str, Path], recursive: bool = True) -> List[str]:
"""
Scan PDF files in directory
Args:
directory: Directory path to scan
recursive: Whether to scan subdirectories recursively
Returns:
List of found PDF file paths
"""
directory = Path(directory)
if not directory.exists():
print(f"Error: Directory '{directory}' does not exist")
return []
if not directory.is_dir():
print(f"Error: '{directory}' is not a valid directory")
return []
pdf_files = []
# Directories to exclude from scanning
exclude_dirs = {'.cache', '__pycache__', '.git', 'node_modules', '.venv', 'venv', '.env'}
if recursive:
# Recursively search all subdirectories
pattern = "**/*.pdf"
else:
# Only search current directory
pattern = "*.pdf"
for pdf_path in directory.glob(pattern):
# Skip if path contains any excluded directory
if any(exclude_dir in pdf_path.parts for exclude_dir in exclude_dirs):
continue
# Also skip hidden directories (starting with .)
if any(part.startswith('.') and part != '.' for part in pdf_path.parts):
continue
if pdf_path.is_file():
# Convert to absolute path
pdf_files.append(str(pdf_path.resolve()))
print(f"Found PDF: {pdf_path}")
self.pdf_files.extend(pdf_files)
return pdf_files
def scan_multiple_directories(self, directories: List[Union[str, Path]], recursive: bool = True) -> List[str]:
"""
Scan multiple directories
Args:
directories: List of directory paths
recursive: Whether to scan recursively
Returns:
List of all found PDF file paths
"""
all_pdfs = []
for directory in directories:
pdfs = self.scan_directory(directory, recursive)
all_pdfs.extend(pdfs)
return all_pdfs
def add_pdf_file(self, file_path: Union[str, Path]) -> bool:
"""
Manually add a single PDF file
Args:
file_path: PDF file path
Returns:
Whether successfully added
"""
file_path = Path(file_path)
if not file_path.exists():
print(f"Error: File '{file_path}' does not exist")
return False
if not file_path.is_file():
print(f"Error: '{file_path}' is not a file")
return False
if file_path.suffix.lower() != '.pdf':
print(f"Error: '{file_path}' is not a PDF file")
return False
abs_path = str(file_path.resolve())
if abs_path not in self.pdf_files:
self.pdf_files.append(abs_path)
print(f"Added PDF: {file_path}")
return True
else:
print(f"PDF already exists: {file_path}")
return False
def generate_jsonl(self, output_file: str = None) -> str:
"""
Generate JSONL config file
Args:
output_file: Output file path, if None use the initialized path
Returns:
Generated JSONL file path
"""
if output_file is None:
output_file = self.output_file
else:
# Handle output file relative path - based on caller's working directory
output_path = Path(output_file)
if not output_path.is_absolute():
caller_cwd = Path(os.environ.get('PWD', os.getcwd()))
output_path = caller_cwd / output_path
output_file = str(output_path)
if not self.pdf_files:
print("Warning: No PDF files found")
return output_file
# Validate and process output file path
output_path = Path(output_file)
# If output path is directory, auto-generate filename
if output_path.exists() and output_path.is_dir():
output_path = output_path / "pdf_list.jsonl"
output_file = str(output_path)
print(f"Warning: Output path is directory, auto-generating filename: {output_file}")
elif output_path.suffix == "":
# If no extension, add .jsonl
output_path = output_path.with_suffix(".jsonl")
output_file = str(output_path)
print(f"Warning: Auto-adding extension: {output_file}")
# Ensure output directory exists
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_file, 'w', encoding='utf-8') as f:
for pdf_path in self.pdf_files:
# Write in JSONL format
json_line = {"raw_content": pdf_path}
f.write(json.dumps(json_line, ensure_ascii=False) + '\n')
print(f"Successfully generated JSONL file: {output_file}")
print(f"Contains {len(self.pdf_files)} PDF files")
return output_file
def preview_results(self, max_items: int = 10):
"""Preview detection results"""
if not self.pdf_files:
print("No PDF files found")
return
print(f"\nDetected {len(self.pdf_files)} PDF files:")
print("-" * 50)
for i, pdf_path in enumerate(self.pdf_files[:max_items]):
print(f"{i + 1:3d}. {pdf_path}")
if len(self.pdf_files) > max_items:
print(f"... and {len(self.pdf_files) - max_items} more files")
print("-" * 50)
def clear_results(self):
"""Clear detection results"""
self.pdf_files.clear()
print("Detection results cleared")
def main():
parser = argparse.ArgumentParser(description='Detect PDF files and generate JSONL config file')
parser.add_argument('input_dir', nargs='?', default='./input',
help='Input directory path to scan (default: ./input)')
parser.add_argument('-o', '--output', default='./.cache/gpu/pdf_list.jsonl',
help='Output JSONL file path (default: ./.cache/gpu/pdf_list.jsonl)')
parser.add_argument('-r', '--recursive', action='store_true', default=True, help='Scan subdirectories recursively')
parser.add_argument('--no-recursive', action='store_false', dest='recursive', help='Do not scan subdirectories')
args = parser.parse_args()
# Validate input directory - handle relative paths
input_path = Path(args.input_dir)
if not input_path.is_absolute():
# If relative path, resolve based on caller's working directory
caller_cwd = Path(os.environ.get('PWD', os.getcwd()))
input_path = caller_cwd / input_path
if not input_path.exists():
print(f"Error: Input directory '{input_path}' does not exist")
return
if not input_path.is_dir():
print(f"Error: '{input_path}' is not a valid directory")
return
# Create detector
detector = PDFDetector(args.output)
# Use resolved input directory
input_directory = str(input_path)
# Scan directory
print(f"Starting directory scan: {input_directory}")
print(f"Recursive mode: {'enabled' if args.recursive else 'disabled'}")
detector.scan_directory(input_directory, args.recursive)
# Generate JSONL file
detector.generate_jsonl(output_file=args.output)
if __name__ == "__main__":
main()
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import argparse
import os
from pathlib import Path
from dataflow.operators.knowledge_cleaning import (
KBCChunkGeneratorBatch,
FileOrURLToMarkdownConverterBatch,
KBCTextCleanerBatch,
KBCMultiHopQAGeneratorBatch,
QAExtractor
)
from dataflow.utils.storage import FileStorage
from dataflow.serving import LocalModelLLMServing_vllm
class KBCleaning_batchvllm_GPUPipeline():
def __init__(self, cache_base="./"):
# resolve the cache_base relative path
cache_path = Path(cache_base)
if not cache_path.is_absolute():
caller_cwd = Path(os.environ.get('PWD', os.getcwd()))
cache_path = caller_cwd / cache_path
self.storage = FileStorage(
first_entry_file_name=str(cache_path / ".cache" / "gpu" / "pdf_list.jsonl"),
cache_path=str(cache_path / ".cache" / "gpu"),
file_name_prefix="batch_cleaning_step",
cache_type="json",
)
self.knowledge_cleaning_step1 = FileOrURLToMarkdownConverterBatch(
intermediate_dir=str(cache_path / ".cache"),
lang="en",
mineru_backend="vlm-vllm-engine", # 可选 pipeline, vlm-vllm-engine, vlm-vllm-transformer, vlm-http-client
)
self.knowledge_cleaning_step2 = KBCChunkGeneratorBatch(
split_method="token",
chunk_size=512,
tokenizer_name="./Qwen2.5-7B-Instruct",
)
self.extract_format_qa = QAExtractor(
qa_key="qa_pairs",
output_json_file="./.cache/data/qa.json",
)
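# Expected input: each line of pdf_list.jsonl is a JSON object such as
# {"raw_content": "/abs/path/to/document.pdf"}, as produced by the PDF detector script.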
def forward(self):
"""执行完整的Pipeline流程"""
print("🔄 Step 1: File/URL to Markdown conversion...")
self.knowledge_cleaning_step1.run(
storage=self.storage.step(),
input_key="raw_content",
output_key="text_path"
)
print("🔄 Step 2: Text splitting into chunks...")
self.knowledge_cleaning_step2.run(
storage=self.storage.step(),
)
print("🔄 Starting LLM serving...")
self.llm_serving = LocalModelLLMServing_vllm(
hf_model_name_or_path="./Qwen2.5-7B-Instruct",
vllm_max_tokens=2048,
vllm_tensor_parallel_size=1, # number of GPUs to use
vllm_gpu_memory_utilization=0.6, # GPU memory utilization
vllm_repetition_penalty=1.2
)
self.knowledge_cleaning_step3 = KBCTextCleanerBatch(
llm_serving=self.llm_serving,
lang="en"
)
self.knowledge_cleaning_step4 = KBCMultiHopQAGeneratorBatch(
llm_serving=self.llm_serving,
lang="en",
)
print("🔄 Step 3: Knowledge cleaning...")
self.knowledge_cleaning_step3.run(
storage=self.storage.step(),
)
print("🔄 Step 4: Multi-hop QA generation...")
self.knowledge_cleaning_step4.run(
storage=self.storage.step(),
)
print("🔄 Step 5: Extract and format QA...")
self.extract_format_qa.run(
storage=self.storage.step(),
input_key="question,reasoning_steps",
output_key="answer"
)
print("✅ Pipeline completed! Output saved to: ./.cache/data/qa.json")
def main():
parser = argparse.ArgumentParser(description="PDF to QA Pipeline")
parser.add_argument("--cache", default="./", help="Cache directory path")
args = parser.parse_args()
print("🚀 Starting KB Cleaning Pipeline...")
print(f"📄 Input: {args.cache}.cache/gpu/pdf_list.jsonl")
print(f"💾 Cache: {args.cache}.cache/gpu/")
print(f"📤 Output: {args.cache}.cache/data/qa.json")
print("-" * 60)
model = KBCleaning_batchvllm_GPUPipeline(cache_base=args.cache)
model.forward()
if __name__ == "__main__":
main()
#!/usr/bin/env python3
"""
LlamaFactory Training Script with YAML Configuration Management
Complete "check-create-read-train" workflow
"""
import yaml
import os
import sys
import subprocess
import argparse
from pathlib import Path
class LlamaFactoryTrainer:
def __init__(self, config_path=None, cache_base="./"):
# resolve the cache_base relative path against the caller's working directory
cache_path = Path(cache_base)
if not cache_path.is_absolute():
caller_cwd = Path(os.environ.get('PWD', os.getcwd()))
cache_path = caller_cwd / cache_path
self.cache_path = cache_path
# if config_path is not given, use the default path under the cache directory
if config_path is None:
config_path = str(cache_path / ".cache" / "train_config.yaml")
else:
# resolve the config_path relative path
config_path_obj = Path(config_path)
if not config_path_obj.is_absolute():
caller_cwd = Path(os.environ.get('PWD', os.getcwd()))
config_path = str(caller_cwd / config_path_obj)
self.config_path = Path(config_path)
def get_default_config(self):
"""Get default configuration - 基于最新LlamaFactory标准"""
return {
# === Basic settings ===
"stage": "sft",
"do_train": True,
# === Model settings ===
"model_name_or_path": "Qwen/Qwen2.5-7B-Instruct",
"template": "qwen",
"trust_remote_code": True,
# === Fine-tuning method settings ===
"finetuning_type": "lora",
"lora_target": "all", # 使用 "all" 而不是具体层名
"lora_rank": 16,
"lora_alpha": 32,
"lora_dropout": 0.05,
# === Dataset settings ===
"dataset": "kb_qa",
"dataset_dir": str(self.cache_path / ".cache" / "data"),
"cutoff_len": 1024,
"max_samples": None,
"overwrite_cache": True,
"preprocessing_num_workers": 4,
# === Training settings ===
"per_device_train_batch_size": 2,
"gradient_accumulation_steps": 4,
"learning_rate": 5e-5,
"num_train_epochs": 2.0,
"lr_scheduler_type": "cosine",
"warmup_ratio": 0.05,
"weight_decay": 0.01,
"max_grad_norm": 1.0,
"optim": "adamw_torch",
# === Output settings ===
"output_dir": str(self.cache_path / ".cache" / "saves" / "qwen2.5_7b_sft_model"),
"overwrite_output_dir": True,
"save_only_model": True,
"plot_loss": True,
# === Logging and checkpoints ===
"logging_steps": 10,
"save_steps": 300,
"save_total_limit": 2,
# === Evaluation settings ===
"val_size": 0.1,
"per_device_eval_batch_size": 2,
"eval_strategy": "steps",
"eval_steps": 300,
# === Hardware settings ===
"fp16": False,
"bf16": True,
"tf32": True,
"dataloader_num_workers": 2,
# === Miscellaneous ===
"seed": 42,
"ddp_timeout": 1800,
"report_to": "none",
"run_name": "qwen2.5_7b_sft_training",
}
def check_and_create_config(self):
"""
Core logic: check config file, read if exists, create if not
"""
# Ensure directory exists
self.config_path.parent.mkdir(parents=True, exist_ok=True)
if self.config_path.exists():
# Config file exists, read it
print(f"Found config file: {self.config_path}")
return self._load_existing_config()
else:
# Config file doesn't exist, create default config
print(f"Creating new config file: {self.config_path}")
return self._create_new_config()
def _load_existing_config(self):
"""Load existing config file"""
try:
with open(self.config_path, 'r', encoding='utf-8') as f:
config = yaml.safe_load(f)
print("Config file loaded successfully")
# Check if new default parameters need to be added
default_config = self.get_default_config()
updated = False
for key, value in default_config.items():
if key not in config:
config[key] = value
updated = True
print(f"Added missing parameter: {key}")
# remove the deprecated formatting parameter
if "formatting" in config:
del config["formatting"]
updated = True
print("Removed deprecated parameter: formatting")
# Save config if updated
if updated:
self._save_config(config)
print("Config file updated")
return config
except Exception as e:
print(f"Failed to read config: {e}")
print("Creating new default config")
return self._create_new_config()
def _create_new_config(self):
"""Create new default config file"""
default_config = self.get_default_config()
if self._save_config(default_config):
print("Default config file created successfully")
return default_config
else:
print("Failed to create config file")
return None
def _save_config(self, config):
"""Save config to file"""
try:
with open(self.config_path, 'w', encoding='utf-8') as f:
yaml.dump(config, f,
default_flow_style=False,
allow_unicode=True,
sort_keys=False,
indent=2,
width=80)
return True
except Exception as e:
print(f"Failed to save config: {e}")
return False
def update_config(self, updates):
"""Update config parameters"""
config = self.check_and_create_config()
if config is None:
return None
config.update(updates)
if self._save_config(config):
print(f"Config updated: {list(updates.keys())}")
return config
return None
def print_config_info(self, config):
"""Print config information"""
if not config:
return
print("\n" + "=" * 50)
print("Training Configuration")
print("=" * 50)
key_info = [
("Model", "model_name_or_path"),
("Dataset", "dataset"),
("Template", "template"),
("LoRA Rank", "lora_rank"),
("Learning Rate", "learning_rate"),
("Epochs", "num_train_epochs"),
("Batch Size", "per_device_train_batch_size"),
("Output Dir", "output_dir"),
]
for name, key in key_info:
print(f"{name:<12}: {config.get(key, 'N/A')}")
effective_batch = config.get("per_device_train_batch_size", 1) * config.get("gradient_accumulation_steps", 1)
print(f"Effective Batch: {effective_batch}")
print("=" * 50)
def check_environment(self):
"""Check training environment"""
print("Checking training environment...")
# use absolute paths based on the caller's working directory
caller_cwd = Path(os.environ.get('PWD', os.getcwd()))
data_dir = caller_cwd / ".cache" / "data"
# Check data files with detailed info
data_files = [
str(data_dir / "qa.json"),
str(data_dir / "dataset_info.json")
]
for file_path in data_files:
if not Path(file_path).exists():
print(f"❌ Missing data file: {file_path}")
return False
else:
file_size = Path(file_path).stat().st_size
print(f"✅ Found data file: {file_path} ({file_size} bytes)")
# inspect the contents of qa.json
try:
import json
qa_file = str(data_dir / "qa.json")
with open(qa_file, 'r', encoding='utf-8') as f:
qa_data = json.load(f)
print(f"✅ QA data loaded: {len(qa_data)} samples")
if len(qa_data) == 0:
print("❌ QA data is empty! Please check data generation pipeline.")
return False
# check the format of the first sample
if qa_data:
sample = qa_data[0]
required_keys = ["instruction", "input", "output"]
missing_keys = [key for key in required_keys if key not in sample]
if missing_keys:
print(f"❌ Missing keys in QA sample: {missing_keys}")
return False
else:
print("✅ QA data format is correct")
except Exception as e:
print(f"❌ Error checking QA data: {e}")
return False
print("✅ Data files check passed")
# Check LlamaFactory
try:
import llamafactory
print("✅ LlamaFactory is installed")
except ImportError:
print("❌ LlamaFactory not installed")
print("Please run: pip install llamafactory[torch,metrics]")
return False
return True
def start_training(self):
"""Start training"""
if not self.check_environment():
print("Environment check failed")
return False
# Load config
config = self.check_and_create_config()
if not config:
print("Failed to load config")
return False
# Show config info
self.print_config_info(config)
# Build training command
train_cmd = f"llamafactory-cli train {self.config_path}"
# Start training
print(f"\nStarting training...")
print(f"Config file: {self.config_path}")
print(f"Command: {train_cmd}")
try:
result = subprocess.run(train_cmd, shell=True, check=True)
print("Training completed!")
return True
except subprocess.CalledProcessError as e:
print(f"Training failed: {e}")
return False
except KeyboardInterrupt:
print("\nTraining interrupted by user")
return False
def main():
parser = argparse.ArgumentParser(description="LlamaFactory Training Script")
parser.add_argument("--config", default=None,
help="Config file path (default: cache_dir/.cache/train_config.yaml)")
parser.add_argument("--cache", default="./", help="Cache directory path")
parser.add_argument("--update-lr", type=float,
help="Update learning rate")
parser.add_argument("--update-epochs", type=int,
help="Update training epochs")
parser.add_argument("--update-batch-size", type=int,
help="Update batch size")
args = parser.parse_args()
# resolve the cache_base relative path
cache_path = Path(args.cache)
if not cache_path.is_absolute():
caller_cwd = Path(os.environ.get('PWD', os.getcwd()))
cache_path = caller_cwd / cache_path
# Create trainer - if no config is given, the default path under the cache directory is used
config_path = args.config
if config_path is None:
config_path = str(cache_path / ".cache" / "train_config.yaml")
trainer = LlamaFactoryTrainer(config_path, cache_base=str(cache_path))
# Update parameters if specified
updates = {}
if args.update_lr:
updates["learning_rate"] = args.update_lr
if args.update_epochs:
updates["num_train_epochs"] = args.update_epochs
if args.update_batch_size:
updates["per_device_train_batch_size"] = args.update_batch_size
if updates:
print("Updating config parameters...")
trainer.update_config(updates)
# Start training
success = trainer.start_training()
if success:
print(f"\nOperation completed!")
print(f"Model saved to: {cache_path / '.cache' / 'saves' / 'qwen2.5_7b_sft_model'}")
else:
sys.exit(1)
if __name__ == "__main__":
main()
#!/usr/bin/env python3
import json
import os
import argparse
from pathlib import Path
def find_input_file(cache_base="./"):
"""Find input file relative to cache_base"""
cache_path = Path(cache_base)
if not cache_path.is_absolute():
caller_cwd = Path(os.environ.get('PWD', os.getcwd()))
cache_path = caller_cwd / cache_path
print(f"Cache base directory: {cache_path}")
# prefer text2qa_step_step3.json if present
possible_paths = [
cache_path / ".cache" / "gpu" / "text2qa_step_step3.json",
cache_path / ".cache" / "gpu" / "batch_cleaning_step_step4.json",
cache_path / "cache" / "gpu" / "batch_cleaning_step_step4.json",
cache_path / "batch_cleaning_step_step4.json",
]
print("Searching for input file...")
for path in possible_paths:
abs_path = path.resolve()
if abs_path.exists():
size = abs_path.stat().st_size
print(f"Found input file: {abs_path} ({size} bytes)")
return abs_path
else:
print(f"Not found: {abs_path}")
print("Input file not found!")
return None
def load_qa_data_from_files(data_items, input_file):
"""Load QA data from enhanced_chunk_path files"""
all_qa_pairs = []
for i, item in enumerate(data_items):
print(f"Processing item {i + 1}/{len(data_items)}: ", end="")
enhanced_path = item.get('enhanced_chunk_path')
if not enhanced_path:
print("Skip (no enhanced_chunk_path)")
continue
# Convert to absolute path
if not os.path.isabs(enhanced_path):
input_file_path = Path(input_file)
cache_gpu_dir = input_file_path.parent
cache_dir = cache_gpu_dir.parent
project_root = cache_dir.parent
clean_path = enhanced_path.lstrip('./')
enhanced_path = project_root / clean_path
else:
enhanced_path = Path(enhanced_path)
if not enhanced_path.exists():
print(f"Skip (file not exists: {enhanced_path})")
continue
try:
with open(enhanced_path, 'r', encoding='utf-8') as f:
enhanced_data = json.load(f)
chunk_qa_pairs = []
if isinstance(enhanced_data, list):
# enhanced_data is chunk list
for chunk in enhanced_data:
if isinstance(chunk, dict) and 'qa_pairs' in chunk:
qa_data = chunk['qa_pairs']
if isinstance(qa_data, dict) and 'qa_pairs' in qa_data:
# Double nested: chunk['qa_pairs']['qa_pairs']
chunk_qa_pairs.extend(qa_data['qa_pairs'])
elif isinstance(qa_data, list):
# Single nested: chunk['qa_pairs'] is directly a list
chunk_qa_pairs.extend(qa_data)
elif isinstance(enhanced_data, dict) and 'qa_pairs' in enhanced_data:
# enhanced_data is single object
qa_data = enhanced_data['qa_pairs']
if isinstance(qa_data, dict) and 'qa_pairs' in qa_data:
chunk_qa_pairs = qa_data['qa_pairs']
elif isinstance(qa_data, list):
chunk_qa_pairs = qa_data
if chunk_qa_pairs and isinstance(chunk_qa_pairs, list):
print(f"Found {len(chunk_qa_pairs)} QA pairs")
all_qa_pairs.extend(chunk_qa_pairs)
else:
print("Skip (no valid QA data)")
continue
except Exception as e:
print(f"Skip (read failed: {e})")
continue
return all_qa_pairs
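# For reference, an enhanced chunk file parsed above has roughly this shape (illustrative only):
# [{"qa_pairs": {"qa_pairs": [{"question": "...", "answer": "...", "reasoning_steps": [{"step": "..."}]}]}}]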
def convert_to_alpaca(input_file, output_dir):
"""Convert QA data to Alpaca format for LlamaFactory"""
print(f"Reading data file: {input_file}")
print(f"Output directory: {output_dir}")
# Read main data file
try:
with open(input_file, 'r', encoding='utf-8') as f:
data = json.load(f)
print(
f"Successfully read main data, type: {type(data)}, length: {len(data) if hasattr(data, '__len__') else 'N/A'}")
except Exception as e:
print(f"Failed to read file: {e}")
return None
if not isinstance(data, list):
print("Expected data in list format")
return None
# Instruction for the training data
instruction = (
"Please answer the following question based on the provided text content. "
"Your response should:\n"
"1. Provide accurate information from the source material\n"
"2. Include relevant analysis and reasoning\n"
"3. Reference specific details or examples when applicable\n"
"4. Maintain clarity and precision in your explanation\n\n"
"Focus on delivering factual, well-reasoned answers based on the text content."
)
print("Loading QA pairs from enhanced files...")
all_qa_pairs = load_qa_data_from_files(data, input_file)
if not all_qa_pairs:
print("No QA pairs found! Please check data structure")
return None
print(f"Total found {len(all_qa_pairs)} QA pairs")
# Process QA pairs
results = []
processed_pairs = 0
for qa in all_qa_pairs:
if not isinstance(qa, dict):
continue
# Extract question and answer
question = ""
answer_text = ""
# Try different possible field names for question
for q_field in ['question', 'Question', 'query', 'Query']:
if q_field in qa and qa[q_field]:
question = qa[q_field].strip()
break
# Try different possible field names for answer
for a_field in ['answer', 'Answer', 'response', 'Response']:
if a_field in qa and qa[a_field]:
answer_text = qa[a_field].strip()
break
# Skip empty questions or answers
if not question or not answer_text:
continue
# Include reasoning steps if available
reasoning_steps = qa.get("reasoning_steps", [])
reasoning_text = ""
if isinstance(reasoning_steps, list):
reasoning_text = "\n".join([
step.get("step", "").strip()
for step in reasoning_steps
if isinstance(step, dict) and step.get("step", "").strip()
])
# Build output (reasoning process + answer)
if reasoning_text:
output_text = f"{reasoning_text}\n\n{answer_text}"
else:
output_text = answer_text
results.append({
"instruction": instruction,
"input": question,
"output": output_text
})
processed_pairs += 1
print(f"\nProcessing statistics:")
print(f"Found QA pairs: {len(all_qa_pairs)}")
print(f"Valid conversions: {processed_pairs}")
if not results:
print("No QA pairs converted!")
return None
# Ensure output directory exists
output_dir.mkdir(parents=True, exist_ok=True)
# Save as qa.json (LlamaFactory standard format)
qa_file = output_dir / "qa.json"
try:
with open(qa_file, 'w', encoding='utf-8') as f:
json.dump(results, f, ensure_ascii=False, indent=2)
file_size = qa_file.stat().st_size
print(f"Conversion complete: {len(results)} QA pairs -> {qa_file} ({file_size} bytes)")
return qa_file
except Exception as e:
print(f"Failed to save file: {e}")
return None
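# For reference, each converted record follows the Alpaca schema expected by LlamaFactory
# (the values below are illustrative, not real data):
# {"instruction": "Please answer the following question based on the provided text content. ...",
#  "input": "What does the source text say about X?",
#  "output": "Step 1: ...\n\nFinal answer ..."}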
def create_llamafactory_config(output_dir):
"""Create dataset_info.json for LlamaFactory"""
print("Creating LlamaFactory configuration...")
dataset_info = {
"kb_qa": {
"file_name": "qa.json",
"columns": {
"prompt": "instruction",
"query": "input",
"response": "output"
}
}
}
config_file = output_dir / "dataset_info.json"
try:
with open(config_file, 'w', encoding='utf-8') as f:
json.dump(dataset_info, f, ensure_ascii=False, indent=2)
print(f"LlamaFactory configuration created: {config_file}")
return config_file
except Exception as e:
print(f"Failed to create configuration: {e}")
return None
def main():
parser = argparse.ArgumentParser(description="Convert Step 3 QA data to LlamaFactory format")
parser.add_argument("--cache", default="./", help="Cache directory path")
args = parser.parse_args()
# Handle cache_base relative path
cache_path = Path(args.cache)
if not cache_path.is_absolute():
caller_cwd = Path(os.environ.get('PWD', os.getcwd()))
cache_path = caller_cwd / cache_path
print("Step 3 to LlamaFactory Converter")
print("=" * 50)
print(f"Cache base directory: {cache_path}")
# Find input file
input_file = find_input_file(str(cache_path))
if not input_file:
print("\nTips:")
print("1. Ensure Step 3 (Text2QA generation) has been completed")
print("2. Check if .cache/gpu/ directory exists")
print("3. Look for text2qa_step_step3.json file")
exit(1)
# Output directory
output_dir = cache_path / ".cache" / "data"
print(f"\nStarting conversion...")
print(f"Input: {input_file}")
print(f"Output directory: {output_dir}")
print("-" * 50)
# Convert data
qa_file = convert_to_alpaca(input_file, output_dir)
if qa_file:
# Create config file
config_file = create_llamafactory_config(output_dir)
if config_file:
print(f"\nConversion completed successfully!")
print(f"Files created:")
print(f" - {qa_file}")
print(f" - {config_file}")
print("Ready for LlamaFactory training!")
else:
print("Configuration file creation failed")
else:
print("Data conversion failed")
if __name__ == "__main__":
main()
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import argparse
import os
from pathlib import Path
from dataflow.operators.knowledge_cleaning import (
KBCChunkGeneratorBatch,
KBCTextCleanerBatch,
KBCMultiHopQAGeneratorBatch,
QAExtractor
)
# from dataflow.operators.knowledge_cleaning import (
# CorpusTextSplitterBatch,
# KnowledgeCleanerBatch,
# MultiHopQAGeneratorBatch,
# )
from dataflow.utils.storage import FileStorage
from dataflow.serving import LocalModelLLMServing_vllm
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['ONNXRUNTIME_THREAD_AFFINITY'] = 'false'
class Text2QAPipeline:
def __init__(self, cache_base="./"):
# resolve the cache_base relative path
cache_path = Path(cache_base)
if not cache_path.is_absolute():
caller_cwd = Path(os.environ.get('PWD', os.getcwd()))
cache_path = caller_cwd / cache_path
self.storage = FileStorage(
first_entry_file_name=str(cache_path / ".cache" / "gpu" / "text_input.jsonl"),
cache_path=str(cache_path / ".cache" / "gpu"),
file_name_prefix="text2qa_step",
cache_type="json",
)
self.text_splitting_step = KBCChunkGeneratorBatch(
split_method="token",
chunk_size=512,
tokenizer_name="Qwen/Qwen2.5-7B-Instruct",
)
self.extract_format_qa = QAExtractor(
qa_key="qa_pairs",
output_json_file="./.cache/data/qa.json",
)
def forward(self):
"""执行完整的Pipeline流程"""
print("Step 1: Text splitting into chunks...")
self.text_splitting_step.run(
storage=self.storage.step(),
)
print("Starting LLM serving...")
self.llm_serving = LocalModelLLMServing_vllm(
hf_model_name_or_path="Qwen/Qwen2.5-7B-Instruct",
vllm_max_tokens=2048,
vllm_tensor_parallel_size=1,
vllm_gpu_memory_utilization=0.6,
vllm_repetition_penalty=1.2
)
self.knowledge_cleaning_step = KBCTextCleanerBatch(
llm_serving=self.llm_serving,
lang="en"
)
self.qa_generation_step = KBCMultiHopQAGeneratorBatch(
llm_serving=self.llm_serving,
lang="en"
)
print("Step 2: Knowledge cleaning...")
self.knowledge_cleaning_step.run(
storage=self.storage.step(),
)
print("Step 3: Multi-hop QA generation...")
self.qa_generation_step.run(
storage=self.storage.step(),
)
print("🔄 Step 4: Extract and format QA...")
self.extract_format_qa.run(
storage=self.storage.step(),
input_key="question,reasoning_steps",
output_key="answer"
)
print("Pipeline completed!")
def main():
parser = argparse.ArgumentParser(description="Text to QA Pipeline")
parser.add_argument("--cache", default="./", help="Cache directory path")
args = parser.parse_args()
print("Starting Text to QA Pipeline...")
print(f"Input: {args.cache}.cache/gpu/text_input.jsonl")
print(f"Cache: {args.cache}.cache/gpu/")
print(f"Output: {args.cache}.cache/gpu/text2qa_step_step3.json")
print("-" * 60)
model = Text2QAPipeline(cache_base=args.cache)
model.forward()
if __name__ == "__main__":
main()
from .operator import OperatorABC, get_operator
from .llm_serving import LLMServingABC
from .wrapper import WrapperABC
from typing import Union, TypeAlias
# type alias definitions
OPERATOR_CLASSES: TypeAlias = Union[OperatorABC, WrapperABC]
LLM_SERVING_CLASSES: TypeAlias = LLMServingABC # a single type can also be written this way
__all__ = [
'OPERATOR_CLASSES',
'LLM_SERVING_CLASSES',
'OperatorABC',
'get_operator',
'LLMServingABC',
'WrapperABC',
]
from abc import ABC, abstractmethod
from typing import Any, List
class LLMServingABC(ABC):
"""Abstract base class for data generators. Which may be used to generate data from a model or API. Called by operators
"""
@abstractmethod
def generate_from_input(self, user_inputs: List[str], system_prompt: str) -> List[str]:
"""
Generate outputs for the given inputs.
user_inputs: List[str], the inputs to the generator; system_prompt: str, the shared system prompt.
"""
pass
@abstractmethod
def start_serving(self):
"""
Start the serving backend so it is ready to handle requests.
"""
pass
@abstractmethod
def cleanup(self):
"""
Cleanup the generator and garbage collect all GPU/CPU memory.
"""
pass
def load_model(self, model_name_or_path: str, **kwargs: Any):
"""
Load the model from the given path.
This method is optional and can be overridden by subclasses if needed.
"""
raise NotImplementedError("This method should be implemented by subclasses.")
from abc import ABC, abstractmethod
from dataflow.logger import get_logger
from .prompt import DIYPromptABC, PromptABC
class OperatorABC(ABC):
def __init__(self):
self.logger = get_logger()
self.ALLOWED_PROMPTS = (PromptABC, DIYPromptABC) # default: accept any prompt class; prompt_restrict narrows this per operator
@abstractmethod
def run(self) -> None:
"""
Main function to run the operator.
"""
pass
def get_operator(operator_name, args) -> OperatorABC:
from dataflow.utils import OPERATOR_REGISTRY
logger = get_logger()
print(operator_name, args)
# look up the operator class first, so a missing name is logged instead of failing on instantiation
operator_cls = OPERATOR_REGISTRY.get(operator_name)
if operator_cls is None:
logger.error(f"operator {operator_name} is not found")
assert operator_cls is not None
operator = operator_cls(args)
logger.info(f"Successfully get operator {operator_name}, args {args}")
print(operator)
return operator
from typing import TypeVar, Protocol, Union, get_type_hints, cast
from functools import wraps
import inspect
# from dataflow.core import OperatorABC
class PromptABC():
def __init__(self):
pass
def build_prompt(self):
raise NotImplementedError
class DIYPromptABC(PromptABC):
def __init__(self):
super().__init__()
def build_prompt(self):
raise NotImplementedError
# class OperatorWithAllowedPrompts(Protocol):
# ALLOWED_PROMPTS: list[type[DIYPromptABC | PromptABC]]
def _make_diyprompt_union(allowed_prompts: tuple[type[PromptABC], ...]):
"""构造一个 Union 类型,包含允许的 prompt + DIYPromptABC 子类 + None"""
return Union[tuple(allowed_prompts) + (DIYPromptABC, type(None))]
# generic type variable representing whatever class is passed in
T = TypeVar("T")
def prompt_restrict(*allowed_prompts: type[DIYPromptABC]):
"""
Decorator: restrict prompt_template to the listed Prompt classes or DIYPromptABC subclasses,
check it at runtime, and update __annotations__ (for get_type_hints and similar tools).
"""
def decorator(cls:T) -> T:
setattr(cls, "ALLOWED_PROMPTS", tuple(allowed_prompts))
# self.ALLOWED_PROMPTS = list(allowed_prompts)
orig_init = cls.__init__
sig = inspect.signature(orig_init) # parse the signature once at decoration time to avoid re-parsing on every instantiation
if "prompt_template" not in sig.parameters:
# if the class's __init__ has no such parameter, only keep the annotation/attribute setup and skip the runtime check
# (alternatively, one could raise here to force classes to declare the parameter)
pass
@wraps(orig_init)
def new_init(self, *args, **kwargs):
# bind the actual arguments to the signature: positional/keyword/default values are mapped to parameter names
try:
bound = sig.bind_partial(self, *args, **kwargs)
bound.apply_defaults()
except TypeError:
# if arguments are incomplete or mismatched, let the original __init__ raise the error
return orig_init(self, *args, **kwargs)
pt = bound.arguments.get("prompt_template", None)
if pt is not None and not isinstance(pt, cls.ALLOWED_PROMPTS):
if not isinstance(pt, DIYPromptABC):
allowed_names = "\n".join(
f" - {c.__module__}.{c.__qualname__}"
for c in cls.ALLOWED_PROMPTS
)
raise TypeError(
f"[{cls.__name__}] Invalid prompt_template type: "
f"{type(pt).__module__}.{type(pt).__qualname__}\n"
f"Expected one of:\n{allowed_names}\n"
f"or a custom subclass of `dataflow.core.prompt.DIYPromptABC.`"
)
return orig_init(self, *args, **kwargs)
cls.__init__ = new_init
# keep the original annotation-exposure logic
cls.__annotations__ = dict(getattr(cls, "__annotations__", {}))
cls.__annotations__["prompt_template"] = _make_diyprompt_union(allowed_prompts)
return cls
return decorator
if __name__ == "__main__":
import pytest
class A(PromptABC): pass
class B(PromptABC): pass
class MyDIY(DIYPromptABC): pass
class Other(PromptABC): pass
@prompt_restrict(A, B)
class Op:
def __init__(self, prompt_template=None):
self.prompt_template = prompt_template
# keyword arguments: allowed
Op(prompt_template=A())
Op(prompt_template=B())
Op(prompt_template=MyDIY())
Op() # None is allowed
# positional arguments: also checked
Op(A()) # ✅
Op(MyDIY()) # ✅
with pytest.raises(TypeError):
Op(Other()) # ❌ not whitelisted and not a DIY subclass
with pytest.raises(TypeError):
Op(object()) # ❌ completely unrelated type
from abc import ABC, abstractmethod
from dataflow.logger import get_logger
class WrapperABC(ABC):
"""
Abstract base class for wrappers.
"""
@abstractmethod
def run(self) -> None:
"""
Main function to run the wrapper.
"""
pass