Commit 97e8278b authored by zzg_666

Adapt to the vLLM backend

Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
# Explicitly include source files
recursive-include dataflow *.py
recursive-include dataflow *.json
recursive-include dataflow *.yaml
recursive-include dataflow *.sh
recursive-include dataflow *.txt
# Include all file types under the example folder
# The example folder path is dataflow/example
recursive-include dataflow/example */
# Include top-level documentation files
include README.md
include requirements.txt
include LICENSE
# Exclude cache and build artifacts
global-exclude *.py[cod]
global-exclude __pycache__/
global-exclude *.so
global-exclude *.egg-info/
# DataFlow
DataFlow is a data preparation system that parses, generates, processes, and evaluates high-quality data from noisy sources (PDFs, plain text, low-quality Q&A) to improve the performance of large language models (LLMs) in specific domains, supporting pretraining, supervised fine-tuning (SFT), reinforcement-learning training, and knowledge-base-backed RAG systems.
## Requirements
| Software | Version |
| :------: | :------: |
| DTK | 25.04.2 |
| python | 3.10.12 |
| transformers | 4.53.3 |
| vllm | 0.9.2+das.opt1.dtk25042 |
| torch | 2.5.1+das.opt1.dtk25042 |
| torchaudio | 2.5.1+das.opt1.dtk25042 |
| torchvision | 0.20.1+das.opt1.dtk25042 |
| flash_mla | 1.0.0+das.opt1.dtk25042 |
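To quickly confirm that the key packages match the table above, here is a minimal check (the exact commands below are just one option):
```bash
pip list | grep -Ei "^(torch|transformers|vllm|flash)"
python -c "import torch, vllm; print(torch.__version__, vllm.__version__)"
```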
## Installation
To run inference on a DCU with vLLM as the backend, use the following commands:
```bash
docker run -it --shm-size 60g --network=host --name dataflow --privileged --device=/dev/kfd --device=/dev/dri --device=/dev/mkfd --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -u root -v /opt/hyhal/:/opt/hyhal/:ro -v $PWD:/home/ image.sourcefind.cn:5000/dcu/admin/base/vllm:0.9.2-ubuntu22.04-dtk25.04.2-py3.10 bash
git clone https://developer.sourcefind.cn/codes/qteam/dataflow
cd dataflow
pip install -e .[vllm]
```
More container images are available from [光源](https://sourcefind.cn/#/service-list).
The special deep-learning libraries this project requires for DCU GPUs can be downloaded from the [光合](https://developer.sourcefind.cn/tool/) developer community.
After installation, you can verify that it works with:
```bash
dataflow -v
```
If the installation succeeded and you have the latest DataFlow release, you will see:
```bash
open-dataflow codebase version: 1.0.7
Checking for updates...
Local version : 1.0.7
PyPI version : 1.0.7
You are using the latest version: 1.0.7
```
## References
- https://github.com/OpenDCAI/DataFlow
from .utils import *
from .version import __version__, version_info
from .logger import get_logger
from .operators import *
from .prompts import *
__all__ = [
'__version__',
'version_info',
'get_logger',
]
def hello():
return "Hello from open-dataflow!"
#!/usr/bin/env python3
# dataflow/cli.py - Enhanced with local model judge support and eval init/run
# ===============================================================
# DataFlow command-line entry point
#   dataflow -v                     show the version and check for updates
#   dataflow init [...]             initialize scripts/configs
#   dataflow env                    show environment information
#   NOTE: the WebUI has been temporarily removed!
#   dataflow webui operators [opts] launch the operator/pipeline UI
#   dataflow webui agent [opts]     launch the DataFlow-Agent UI (backend included)
#   dataflow pdf2model init/train   PDF-to-model training pipeline
#   dataflow text2model init/train  text-to-model training pipeline
#   dataflow chat                   chat interface
#   dataflow eval init              initialize evaluation config files
#   dataflow eval api               run API model evaluation
#   dataflow eval local             run local model evaluation
# ===============================================================
import os
import argparse
import requests
import sys
import re
import yaml
import json
import subprocess
from pathlib import Path
from colorama import init as color_init, Fore, Style
from dataflow.cli_funcs import cli_env, cli_init  # existing project utilities
from dataflow.version import __version__  # version number
color_init(autoreset=True)
PYPI_API_URL = "https://pypi.org/pypi/open-dataflow/json"
# ---------------- Version check ----------------
def version_and_check_for_updates() -> None:
width = os.get_terminal_size().columns
print(Fore.BLUE + "=" * width + Style.RESET_ALL)
print(f"open-dataflow codebase version: {__version__}")
try:
r = requests.get(PYPI_API_URL, timeout=5)
r.raise_for_status()
remote = r.json()["info"]["version"]
print("\tChecking for updates...")
print(f"\tLocal version : {__version__}")
print(f"\tPyPI version : {remote}")
if remote != __version__:
print(Fore.YELLOW + f"New version available: {remote}."
" Run 'pip install -U open-dataflow' to upgrade."
+ Style.RESET_ALL)
else:
print(Fore.GREEN + f"You are using the latest version: {__version__}" + Style.RESET_ALL)
except requests.exceptions.RequestException as e:
print(Fore.RED + "Failed to query PyPI – check your network." + Style.RESET_ALL)
print("Error:", e)
print(Fore.BLUE + "=" * width + Style.RESET_ALL)
# ---------------- Smart chat ----------------
def check_current_dir_for_model():
"""检查当前目录的模型文件,优先识别微调模型"""
current_dir = Path.cwd()
    # LoRA adapter files
adapter_files = [
"adapter_config.json",
"adapter_model.bin",
"adapter_model.safetensors"
]
    # Base model files
model_files = [
"config.json",
"pytorch_model.bin",
"model.safetensors",
"tokenizer.json",
"tokenizer_config.json"
]
    # Prefer adapter files (fine-tuned model): if any adapter file exists,
    # return only the fine-tuned model, regardless of base model files
if any((current_dir / f).exists() for f in adapter_files):
return [("fine_tuned_model", current_dir)]
    # Only check for a base model when there are no adapter files
if any((current_dir / f).exists() for f in model_files):
return [("base_model", current_dir)]
return []
def get_latest_trained_model(cache_path="./"):
"""查找最新训练的模型,支持text2model和pdf2model,按时间戳排序"""
current_dir = Path.cwd()
cache_path_obj = Path(cache_path)
if not cache_path_obj.is_absolute():
cache_path_obj = current_dir / cache_path_obj
saves_dir = cache_path_obj / ".cache" / "saves"
if not saves_dir.exists():
return None, None
all_models = []
for dir_path in saves_dir.iterdir():
if not dir_path.is_dir():
continue
model_type = None
timestamp = None
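        # Recognized cache layouts (illustrative names):
        #   .cache/saves/text2model_cache_20250101_123456 -> text2model
        #   .cache/saves/pdf2model_cache_20250101_123456  -> pdf2model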
        # text2model format (text2model_cache_YYYYMMDD_HHMMSS)
if dir_path.name.startswith('text2model_cache_'):
timestamp_part = dir_path.name.replace('text2model_cache_', '')
if len(timestamp_part) == 15 and timestamp_part[8] == '_':
date_part = timestamp_part[:8]
time_part = timestamp_part[9:]
if date_part.isdigit() and time_part.isdigit() and len(time_part) == 6:
model_type = 'text2model'
timestamp = timestamp_part
        # pdf2model format (pdf2model_cache_YYYYMMDD_HHMMSS)
elif dir_path.name.startswith('pdf2model_cache_'):
timestamp_part = dir_path.name.replace('pdf2model_cache_', '')
if len(timestamp_part) == 15 and timestamp_part[8] == '_':
date_part = timestamp_part[:8]
time_part = timestamp_part[9:]
if date_part.isdigit() and time_part.isdigit() and len(time_part) == 6:
model_type = 'pdf2model'
timestamp = timestamp_part
        # Other possible model directories
        else:
            # Try to extract a timestamp from the directory name
timestamp_match = re.search(r'(\d{8}_\d{6})', dir_path.name)
if timestamp_match:
                model_type = 'pdf2model'  # default to pdf2model
timestamp = timestamp_match.group(1)
elif 'qwen' in dir_path.name.lower() or 'model' in dir_path.name.lower():
                # No timestamp, but it looks like a model directory: use the mtime
                model_type = 'pdf2model'  # default to pdf2model
mtime = dir_path.stat().st_mtime
                # Convert the mtime to the timestamp format so it sorts correctly
import datetime
dt = datetime.datetime.fromtimestamp(mtime)
timestamp = dt.strftime("%Y%m%d_%H%M%S")
if model_type and timestamp:
all_models.append((dir_path, model_type, timestamp))
if not all_models:
return None, None
    # Sort by timestamp, newest first (regardless of model type)
all_models.sort(key=lambda x: x[2], reverse=True)
latest_model_path, model_type, timestamp = all_models[0]
return latest_model_path, model_type
def call_dataflow_chat(model_path, model_type=None):
"""调用dataflow的聊天功能(用于微调模型)"""
# 判断模型类型
if model_type is None:
# 从路径判断类型
path_str = str(model_path)
if 'text2model' in path_str:
model_type = 'text2model'
elif 'pdf2model' in path_str:
model_type = 'pdf2model'
else:
            # Cannot tell; default to trying text2model
model_type = 'text2model'
if model_type == 'text2model':
try:
from dataflow.cli_funcs.cli_text import cli_text2model_chat
return cli_text2model_chat(str(model_path))
except ImportError:
print("Cannot find text model chat function")
return False
elif model_type == 'pdf2model':
try:
from dataflow.cli_funcs.cli_pdf import cli_pdf2model_chat
return cli_pdf2model_chat(str(model_path))
except ImportError:
print("Cannot find PDF model chat function")
return False
else:
print(f"Unknown model type: {model_type}")
return False
def call_llamafactory_chat(model_path):
"""调用llamafactory的聊天功能(用于基础模型)"""
import subprocess
chat_cmd = [
"llamafactory-cli", "chat",
"--model_name_or_path", str(model_path)
]
try:
result = subprocess.run(chat_cmd, check=True)
return True
except subprocess.CalledProcessError as e:
print(f"LlamaFactory chat failed: {e}")
return False
except FileNotFoundError:
print("llamafactory-cli not found. Please install LlamaFactory:")
print("pip install llamafactory[torch,metrics]")
return False
def smart_chat_command(model_path=None, cache_path="./"):
"""智能聊天命令,统一处理各种模型类型,不自动下载"""
if model_path:
        # If a model path was specified explicitly, use it directly
model_path_obj = Path(model_path)
if not model_path_obj.exists():
print(f"Specified model path does not exist: {model_path}")
return False
print(f"{Fore.CYAN}Using specified model: {model_path}{Style.RESET_ALL}")
        # Check for adapter files
adapter_files = [
"adapter_config.json",
"adapter_model.bin",
"adapter_model.safetensors"
]
has_adapter = any((model_path_obj / f).exists() for f in adapter_files)
if has_adapter:
            # Adapter present: use dataflow chat
return call_dataflow_chat(model_path)
else:
            # No adapter: use llamafactory chat
return call_llamafactory_chat(model_path)
    # Check the current directory
detected_models = check_current_dir_for_model()
if detected_models:
        # Prefer fine_tuned_model (adapter)
for model_type, path in detected_models:
if model_type == "fine_tuned_model":
print(f"{Fore.GREEN}Found trained model in current directory: {path.name}{Style.RESET_ALL}")
return call_dataflow_chat(path)
        # Without an adapter, fall back to base_model
for model_type, path in detected_models:
if model_type == "base_model":
print(f"{Fore.YELLOW}Found base model in current directory: {path.name}{Style.RESET_ALL}")
print(f"{Fore.CYAN}Starting chat interface...{Style.RESET_ALL}")
return call_llamafactory_chat(path)
    # Check the cache for trained models
latest_model, model_type = get_latest_trained_model(cache_path)
if latest_model:
model_name = Path(latest_model).name
print(f"{Fore.GREEN}Found trained model from cache: {model_name}{Style.RESET_ALL}")
print(f"{Fore.CYAN}Starting chat interface...{Style.RESET_ALL}")
        # Check whether the cached model has adapter files
latest_model_path = Path(latest_model)
adapter_files = [
"adapter_config.json",
"adapter_model.bin",
"adapter_model.safetensors"
]
has_adapter = any((latest_model_path / f).exists() for f in adapter_files)
if has_adapter:
return call_dataflow_chat(latest_model, model_type)
else:
print(f"No adapter files found in {latest_model}")
print("This doesn't appear to be a trained model directory.")
return False
    # Nothing found: print guidance instead of downloading
print("No model found in current directory or cache.")
print()
print("Options:")
print("1. Train a model first:")
print(" dataflow text2model init && dataflow text2model train")
print(" dataflow pdf2model init && dataflow pdf2model train")
print()
print("2. Use an existing model:")
print(" dataflow chat --model /path/to/your/model")
print()
print("3. Download a model manually and place it in current directory")
return False
# ---------------- Eval command handlers ----------------
def handle_python_config_init():
"""处理Python配置文件初始化"""
try:
from dataflow.cli_funcs.cli_eval import DataFlowEvalCLI
cli = DataFlowEvalCLI()
        success = cli.init_eval_files()  # correct method name (plural), takes no arguments
if success:
print("Configuration files initialized successfully")
else:
print("Configuration files initialization failed")
return success
except ImportError as e:
print(f"Python config evaluation module unavailable: {e}")
print("Please check if dataflow.cli_funcs.cli_eval module exists")
return False
except Exception as e:
print(f"Configuration file initialization failed: {e}")
return False
def handle_python_config_eval(eval_type: str, args=None):
"""处理Python配置文件评估模式"""
try:
from dataflow.cli_funcs.cli_eval import DataFlowEvalCLI
cli = DataFlowEvalCLI()
        # Use the default file name
eval_file = f"eval_{eval_type}.py"
print(f"Starting {eval_type} model evaluation: {eval_file}")
        # Run the evaluator with the config file
success = cli.run_eval_file(eval_file)
if success:
print(f"{eval_type.upper()} model evaluation completed successfully")
else:
print(f"{eval_type.upper()} model evaluation failed")
return success
except ImportError as e:
print(f"Python config evaluation module unavailable: {e}")
print("Please check if dataflow.cli_funcs.cli_eval module exists")
return False
except Exception as e:
print(f"Python config evaluation failed: {e}")
return False
def handle_eval_command(args):
"""处理评估命令 - 支持自动检测和模型指定"""
try:
eval_action = getattr(args, 'eval_action', None)
        # init subcommand
if eval_action == 'init':
return handle_python_config_init()
        # api subcommand
elif eval_action == 'api':
return handle_python_config_eval('api', args)
        # local subcommand
elif eval_action == 'local':
return handle_python_config_eval('local', args)
        # No subcommand given: show help
else:
print("DataFlow Evaluation Tool")
print()
print("Available commands:")
print(" dataflow eval init # Initialize evaluation config files")
print(" dataflow eval api # Run API model evaluation (auto-detect models)")
print(" dataflow eval local # Run local model evaluation (auto-detect models)")
print()
print("Complete evaluation workflow:")
print(" 1. dataflow eval local # Auto-detect and evaluate local models")
print(" 2. View generated evaluation report # model_comparison_report.json")
print()
print("Config file descriptions:")
print(" - eval_api.py: API evaluator config (GPT-4o etc.)")
print(" - eval_local.py: Local evaluator config")
return False
except Exception as e:
print(f"Evaluation command execution failed: {e}")
import traceback
traceback.print_exc()
return False
# ---------------- CLI main ----------------
def build_arg_parser() -> argparse.ArgumentParser:
"""构建参数解析器"""
parser = argparse.ArgumentParser(
prog="dataflow",
description=f"DataFlow Command-Line Interface (v{__version__})",
)
parser.add_argument("-v", "--version", action="store_true", help="Show version and exit")
    # ============ Top-level subcommands ============ #
top = parser.add_subparsers(dest="command", required=False)
# --- init ---
p_init = top.add_parser("init", help="Initialize scripts/configs in current dir")
p_init_sub = p_init.add_subparsers(dest="subcommand", required=False)
p_init_sub.add_parser("all", help="Init all components").set_defaults(subcommand="all")
p_init_sub.add_parser("reasoning", help="Init reasoning components").set_defaults(subcommand="reasoning")
# --- env ---
top.add_parser("env", help="Show environment information")
# --- chat ---
p_chat = top.add_parser("chat", help="Start chat interface with trained model")
p_chat.add_argument("--model", default=None, help="Model path (default: use latest trained model from cache)")
p_chat.add_argument("--cache", default="./", help="Cache directory path")
    # --- eval ---
p_eval = top.add_parser("eval", help="Model evaluation using BenchDatasetEvaluator")
eval_sub = p_eval.add_subparsers(dest="eval_action", help="Evaluation actions")
    # eval init
    eval_init = eval_sub.add_parser("init", help="Initialize evaluation configuration file")
    # eval api
    eval_api = eval_sub.add_parser("api", help="Run API model evaluation")
    # eval local
    eval_local = eval_sub.add_parser("local", help="Run local model evaluation")
# --- pdf2model ---
p_pdf2model = top.add_parser("pdf2model", help="PDF to model training pipeline")
p_pdf2model.add_argument("--cache", default="./", help="Cache directory path")
p_pdf2model_sub = p_pdf2model.add_subparsers(dest="pdf2model_action", required=True)
p_pdf2model_init = p_pdf2model_sub.add_parser("init", help="Initialize PDF to model pipeline")
p_pdf2model_train = p_pdf2model_sub.add_parser("train", help="Start training after PDF processing")
p_pdf2model_train.add_argument("--lf_yaml", default=None,
help="LlamaFactory config file (default: {cache}/.cache/train_config.yaml)")
# --- text2model ---
p_text2model = top.add_parser("text2model", help="Train model from JSON/JSONL data")
p_text2model_sub = p_text2model.add_subparsers(dest="text2model_action", required=True)
p_text2model_init = p_text2model_sub.add_parser("init", help="Initialize text2model pipeline")
p_text2model_init.add_argument("--cache", default="./", help="Cache directory path")
p_text2model_train = p_text2model_sub.add_parser("train", help="Start training after text processing")
p_text2model_train.add_argument('input_dir', nargs='?', default='./',
help='Input directory to scan (default: ./)')
p_text2model_train.add_argument('--input-keys', default=None,
help='Fields to process (default: text)')
p_text2model_train.add_argument("--lf_yaml", default=None,
help="LlamaFactory config file (default: {cache}/.cache/train_config.yaml)")
# --- webui ---
p_webui = top.add_parser("webui", help="Launch Gradio WebUI")
p_webui.add_argument("-H", "--host", default="127.0.0.1", help="Bind host (default 127.0.0.1)")
p_webui.add_argument("-P", "--port", type=int, default=7862, help="Port (default 7862)")
p_webui.add_argument("--show-error", action="store_true", help="Show Gradio error tracebacks")
    # webui second-level subcommands: operators / agent
w_sub = p_webui.add_subparsers(dest="ui_mode", required=False)
w_sub.add_parser("operators", help="Launch operator / pipeline UI")
w_sub.add_parser("agent", help="Launch DataFlow-Agent UI (backend included)")
w_sub.add_parser("pdf", help="Launch PDF Knowledge Base Cleaning UI")
return parser
def main() -> None:
"""主入口函数"""
parser = build_arg_parser()
args = parser.parse_args()
    # ---------- Top-level dispatch ----------
if args.version:
version_and_check_for_updates()
return
if args.command == "init":
cli_init(subcommand=args.subcommand or "base")
elif args.command == "env":
cli_env()
elif args.command == "eval":
handle_eval_command(args)
elif args.command == "pdf2model":
if args.pdf2model_action == "init":
from dataflow.cli_funcs.cli_pdf import cli_pdf2model_init
cli_pdf2model_init(cache_path=args.cache)
elif args.pdf2model_action == "train":
from dataflow.cli_funcs.cli_pdf import cli_pdf2model_train
# If no lf_yaml specified, use default path relative to cache
lf_yaml = args.lf_yaml or f"{args.cache}/.cache/train_config.yaml"
cli_pdf2model_train(lf_yaml=lf_yaml, cache_path=args.cache)
elif args.command == "text2model":
from dataflow.cli_funcs.cli_text import cli_text2model_init, cli_text2model_train
if args.text2model_action == "init":
cli_text2model_init(cache_path=getattr(args, 'cache', './'))
elif args.text2model_action == "train":
            # If lf_yaml is not specified, use the default path
lf_yaml = getattr(args, 'lf_yaml', None) or "./.cache/train_config.yaml"
cli_text2model_train(input_keys=getattr(args, 'input_keys', None), lf_yaml=lf_yaml)
elif args.command == "chat":
smart_chat_command(model_path=args.model, cache_path=args.cache)
elif args.command == "webui":
        # Default to operators
mode = args.ui_mode or "operators"
if mode == "operators":
print("Currently webui is under maintenance. Please check back later.")
# from dataflow.webui.operator_pipeline import demo
# demo.launch(
# server_name=args.host,
# server_port=args.port,
# show_error=args.show_error,
# )
elif mode == "agent":
print("Agent UI is deprecated in Dataflow main repo, please use the dedicated https://github.com/OpenDCAI/DataFlow-Agent repo.")
elif mode == "pdf":
print("Currently webui is under maintenance. Please check back later.")
# from dataflow.webui import kbclean_webui
# kbclean_webui.create_ui().launch()
else:
parser.error(f"Unknown ui_mode {mode!r}")
if __name__ == "__main__":
main()
from .cli_init import cli_init
from .cli_env import cli_env
__all__ = [
"cli_env",
"cli_init"
]
import os
import torch
import platform
from colorama import init, Fore, Style
from dataflow import __version__
import importlib.metadata
from .paths import DataFlowPath
def is_torch_cuda_available():
"""
Check if CUDA is available for PyTorch.
"""
return torch.cuda.is_available()
def get_env_info():
    info = {
        "`Dataflow` version": __version__,
        "`Dataflow` install path": DataFlowPath.get_dataflow_dir(),
        "Platform": platform.platform(),
        "Python version": platform.python_version(),
        "PyTorch version": torch.__version__,
    }
    # Report torchvision's own version (it is optional, so guard the import)
    try:
        import torchvision
        info["Torchvision version"] = torchvision.__version__
    except Exception:
        pass
if is_torch_cuda_available():
info["PyTorch version"] += " (GPU)"
info["GPU type"] = torch.cuda.get_device_name()
info["GPU number"] = torch.cuda.device_count()
info["GPU memory"] = f"{torch.cuda.mem_get_info()[1] / (1024**3):.2f}GB"
try:
import deepspeed # type: ignore
info["DeepSpeed version"] = deepspeed.__version__
except Exception:
pass
try:
import bitsandbytes # type: ignore
info["Bitsandbytes version"] = bitsandbytes.__version__
except Exception:
pass
try:
import vllm
info["vLLM version"] = vllm.__version__
except Exception:
pass
try:
import sglang
info["SGLang version"] = sglang.__version__
except Exception:
pass
try:
mineru_version = importlib.metadata.version("mineru")
info["MinerU version"] = mineru_version
except Exception:
pass
try:
import subprocess
        # Get the directory of the dataflow package and read the git commit
        # hash in a subprocess, without changing the current working directory
        dataflow_dir = os.path.dirname(os.path.abspath(__file__))
        commit_info = subprocess.run(["git", "rev-parse", "HEAD"],
                                     capture_output=True, text=True, check=True,
                                     cwd=dataflow_dir)
commit_hash = commit_info.stdout.strip()
info["Git commit"] = commit_hash
except Exception:
pass
print(Fore.BLUE + "=" * os.get_terminal_size().columns + Style.RESET_ALL)
print("\n" + "\n".join([f"- {key}: {value}" for key, value in info.items()]) + "\n")
print(Fore.BLUE + "=" * os.get_terminal_size().columns + Style.RESET_ALL)
def cli_env():
get_env_info()
if __name__ == "__main__":
    get_env_info()
# dataflow/cli_funcs/cli_eval.py
"""DataFlow 评估工具"""
import os
import json
import shutil
import importlib.util
from pathlib import Path
from typing import List, Dict, Any
from datetime import datetime
from dataflow import get_logger
from dataflow.serving import LocalModelLLMServing_vllm
from dataflow.operators.reasoning import ReasoningAnswerGenerator
from dataflow.prompts.reasoning.diy import DiyAnswerGeneratorPrompt
from dataflow.utils.storage import FileStorage
import torch
import gc
logger = get_logger()
DEFAULT_ANSWER_PROMPT = """Please answer the following question based on the provided academic literature. Your response should:
1. Provide accurate information from the source material
2. Include relevant scientific reasoning and methodology
3. Reference specific findings, data, or conclusions when applicable
4. Maintain academic rigor and precision in your explanation
Question: {question}
Answer:"""
class EvaluationPipeline:
"""评估管道"""
def __init__(self, config: Dict[str, Any]):
self.config = config
# self.cli_args = cli_args
self.prepared_models = []
self.generated_files = []
def run(self) -> bool:
try:
            # 1. Resolve the target models
self.target_models = self._get_target_models()
if not self.target_models:
logger.error("No TARGET_MODELS found in config")
return False
self.prepared_models = self._prepare_models()
if not self.prepared_models:
return False
            # 2. Generate answers
self.generated_files = self._generate_answers()
if not self.generated_files:
return False
            # 3. Run the evaluation
results = self._run_evaluation()
            # 4. Generate the report
self._generate_report(results)
return True
except Exception as e:
logger.error(f"Evaluation failed: {e}")
import traceback
traceback.print_exc()
return False
def _get_target_models(self) -> List:
"""获取目标模型列表"""
target_config = self.config.get("TARGET_MODELS", [])
if not isinstance(target_config, list):
logger.error(f"TARGET_MODELS must be a list, got {type(target_config)}")
return []
if not target_config:
logger.error("TARGET_MODELS is empty")
return []
return target_config
def _prepare_models(self) -> List[Dict]:
"""准备模型信息"""
prepared = []
default_config = self.config.get("DEFAULT_MODEL_CONFIG", {})
for idx, item in enumerate(self.target_models, 1):
if isinstance(item, str):
model_info = {
"name": Path(item).name,
"path": item,
"type": "local",
**default_config
}
elif isinstance(item, dict):
if "path" not in item:
logger.error(f"Model at index {idx} missing 'path'")
continue
                model_info = {
                    **default_config,  # 1. defaults first
                    **item,            # 2. user config overrides the defaults
                    "name": item.get("name", Path(item["path"]).name),  # 3. ensure the name field is set
                    "type": "local"    # 4. force the type
                }
else:
logger.error(f"Invalid model format at index {idx}")
continue
prepared.append(model_info)
return prepared
def _clear_vllm_cache(self):
"""清理 vLLM 缓存"""
cache_paths = [
Path.home() / ".cache" / "vllm" / "torch_compile_cache",
Path.home() / ".cache" / "vllm"
]
for cache_path in cache_paths:
if cache_path.exists():
try:
shutil.rmtree(cache_path)
except Exception as e:
logger.warning(f"Failed to clear cache: {e}")
def _generate_answers(self) -> List[Dict]:
"""生成模型答案"""
generated_files = []
data_config = self.config.get("DATA_CONFIG", {})
input_file = data_config.get("input_file", "./.cache/data/qa.json")
if not Path(input_file).exists():
logger.error(f"Input file not found: {input_file}")
return []
self._clear_vllm_cache()
for idx, model_info in enumerate(self.prepared_models, 1):
llm_serving = None
answer_generator = None
storage = None
try:
logger.info(f"[{idx}/{len(self.prepared_models)}] Processing: {model_info['name']}")
cache_dir = model_info.get('cache_dir', './.cache/eval')
Path(cache_dir).mkdir(parents=True, exist_ok=True)
output_file = f"{cache_dir}/answers_{model_info['name']}.json"
                # Load the model
llm_serving = LocalModelLLMServing_vllm(
hf_model_name_or_path=model_info['path'],
vllm_tensor_parallel_size=model_info.get('tensor_parallel_size', 2),
vllm_max_tokens=model_info.get('max_tokens', 1024),
vllm_gpu_memory_utilization=model_info.get('gpu_memory_utilization', 0.8)
)
                # Answer generator
custom_prompt = model_info.get('answer_prompt', DEFAULT_ANSWER_PROMPT)
answer_generator = ReasoningAnswerGenerator(
llm_serving=llm_serving,
prompt_template=DiyAnswerGeneratorPrompt(custom_prompt)
)
                # Storage
cache_path = f"{cache_dir}/{model_info['name']}_generation"
storage = FileStorage(
first_entry_file_name=input_file,
cache_path=cache_path,
file_name_prefix=model_info.get('file_prefix', 'answer_gen'),
cache_type=model_info.get('cache_type', 'json')
)
                # Run generation
answer_generator.run(
storage=storage.step(),
input_key=data_config.get("question_key", "input"),
output_key=model_info.get('output_key', 'model_generated_answer')
)
                # Save the results
file_prefix = model_info.get('file_prefix', 'answer_gen')
cache_type = model_info.get('cache_type', 'json')
                # Find all matching step files
pattern = f"{file_prefix}_step*.{cache_type}"
matching_files = sorted(Path(cache_path).glob(pattern))
if matching_files:
                    # Use the newest file (the last step)
gen_file = matching_files[-1]
shutil.copy2(gen_file, output_file)
generated_files.append({
"model_name": model_info['name'],
"model_path": model_info['path'],
"file_path": output_file
})
else:
logger.error(f"No generated file found for {model_info['name']} in {cache_path}")
continue
except Exception as e:
logger.error(f"Failed to process {model_info['name']}: {e}")
continue
finally:
if answer_generator is not None:
del answer_generator
if storage is not None:
del storage
if llm_serving is not None:
del llm_serving
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
return generated_files
def _run_evaluation(self) -> List[Dict]:
"""运行评估"""
try:
judge_serving = self.config["create_judge_serving"]()
except Exception as e:
logger.error(f"Failed to create judge: {e}")
return []
results = []
eval_config = self.config.get("EVALUATOR_RUN_CONFIG", {})
for file_info in self.generated_files:
try:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
result_file = f"./eval_results/{timestamp}_{file_info['model_name']}/result.json"
Path(result_file).parent.mkdir(parents=True, exist_ok=True)
storage = self.config["create_storage"](
file_info["file_path"],
f"./.cache/eval/{file_info['model_name']}"
)
evaluator = self.config["create_evaluator"](judge_serving, result_file)
evaluator.run(
storage=storage.step(),
input_test_answer_key=eval_config.get("input_test_answer_key", "model_generated_answer"),
input_gt_answer_key=eval_config.get("input_gt_answer_key", "output"),
input_question_key=eval_config.get("input_question_key", "input")
)
if Path(result_file).exists():
with open(result_file, 'r') as f:
data = json.load(f)
if data:
data[0]["model_name"] = file_info['model_name']
results.append(data[0])
except Exception as e:
logger.error(f"Eval failed for {file_info['model_name']}: {e}")
continue
return results
def _generate_report(self, results: List[Dict]):
"""生成报告"""
if not results:
logger.warning("No results")
return
sorted_results = sorted(results, key=lambda x: x.get("accuracy", 0), reverse=True)
print("\n" + "=" * 60)
print("Model Evaluation Results")
print("=" * 60)
for i, r in enumerate(sorted_results, 1):
print(f"{i}. {r['model_name']}")
print(f" Accuracy: {r.get('accuracy', 0):.3f}")
print(f" Total: {r.get('total_samples', 0)}")
print(f" Matched: {r.get('matched_samples', 0)}")
print()
print("=" * 60)
        # Save the detailed report
report_file = "./eval_results/report.json"
Path(report_file).parent.mkdir(parents=True, exist_ok=True)
with open(report_file, 'w') as f:
json.dump({"results": sorted_results}, f, indent=2)
print(f"Detailed report: {report_file}")
class DataFlowEvalCLI:
"""CLI工具"""
def __init__(self):
self.current_dir = Path.cwd()
def _get_template_path(self, eval_type: str) -> Path:
current_file = Path(__file__)
dataflow_dir = current_file.parent.parent
return dataflow_dir / "cli_funcs" / "eval_pipeline" / f"eval_{eval_type}.py"
def init_eval_files(self):
"""初始化配置文件"""
files = [("eval_api.py", "api"), ("eval_local.py", "local")]
existing = [f for f, _ in files if (self.current_dir / f).exists()]
if existing:
if input(f"{', '.join(existing)} exists. Overwrite? (y/n): ").lower() != 'y':
return False
for filename, eval_type in files:
try:
template = self._get_template_path(eval_type)
if not template.exists():
logger.error(f"Template not found: {template}")
continue
shutil.copy2(template, self.current_dir / filename)
logger.info(f"Created: {filename}")
except Exception as e:
logger.error(f"Failed: {e}")
logger.info("You must modified the eval_api.py or eval_local.py before you run dataflow eval api/local")
return True
def run_eval_file(self, eval_file: str):
"""运行评估"""
config_path = self.current_dir / eval_file
if not config_path.exists():
logger.error(f"Config not found: {eval_file}")
return False
try:
spec = importlib.util.spec_from_file_location("config", config_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
config = module.get_evaluator_config()
return run_evaluation(config)
except Exception as e:
logger.error(f"Failed: {e}")
return False
def run_evaluation(config):
"""运行评估"""
pipeline = EvaluationPipeline(config)
return pipeline.run()
import os
import re
from colorama import init, Fore, Style
from .paths import DataFlowPath
from .copy_funcs import copy_files_without_recursion, copy_file, copy_files_recursively
def _copy_scripts():
target_dir = os.getcwd()
if not os.path.exists(target_dir):
os.makedirs(target_dir)
# script_path = DataFlowPath.get_dataflow_scripts_dir()
copy_files_recursively(DataFlowPath.get_dataflow_scripts_dir(), target_dir)
def _copy_pipelines():
target_dir = os.getcwd()
if not os.path.exists(target_dir):
os.makedirs(target_dir)
copy_files_recursively(DataFlowPath.get_dataflow_pipelines_dir(), target_dir)
def _copy_playground():
target_dir = os.getcwd()
if not os.path.exists(target_dir):
os.makedirs(target_dir)
copy_files_recursively(DataFlowPath.get_dataflow_playground_dir(), target_dir)
def _copy_examples():
target_dir = os.path.join(os.getcwd(), "example_data")
if not os.path.exists(target_dir):
os.makedirs(target_dir)
copy_files_recursively(DataFlowPath.get_dataflow_example_dir(), target_dir)
def cli_init(subcommand):
print(f'{Fore.GREEN}Initializing in current working directory...{Style.RESET_ALL}')
    # base initialization that contains only the default scripts
if subcommand == "base":
_copy_pipelines()
_copy_examples()
_copy_playground()
# if subcommand == "model_zoo":
# _copy_train_scripts()
# _copy_demo_runs()
# _copy_demo_configs()
# _copy_dataset_json()
# # base initialize that only contain default scripts
# if subcommand == "backbone":
# _copy_train_scripts()
# _copy_demo_runs()
# _copy_demo_configs()
# _copy_dataset_json()
# print(f'{Fore.GREEN}Successfully initialized IMDLBenCo scripts.{Style.RESET_ALL}')
#!/usr/bin/env python3
"""
DataFlow PDF2Model CLI Module - dataflow/cli_funcs/cli_pdf.py
PDF to Model training pipeline with init/train/chat commands
"""
import subprocess
import sys
import yaml
import json
import os
import datetime
from pathlib import Path
from colorama import Fore, Style
from dataflow import get_logger
from .paths import DataFlowPath
logger = get_logger()
def run_script_with_args(script_path: Path, description: str, args: list = None, cwd: str = None) -> bool:
"""Run a Python script with arguments and real-time output"""
print(f"\n{Fore.BLUE}{description}{Style.RESET_ALL}")
cmd = [sys.executable, str(script_path)]
if args:
cmd.extend(args)
print(f"Running: {' '.join(cmd)}")
if cwd:
print(f"Working directory: {cwd}")
try:
result = subprocess.run(cmd, cwd=cwd, check=True,
stdout=sys.stdout, stderr=sys.stderr, text=True)
print(f"{Fore.GREEN}{description} completed{Style.RESET_ALL}")
return True
except subprocess.CalledProcessError as e:
print(f"{Fore.RED}{description} failed{Style.RESET_ALL}")
return False
def get_dataflow_script_path(script_name: str) -> Path:
"""Get the path of dataflow built-in scripts"""
try:
import dataflow
dataflow_path = Path(dataflow.__file__).parent
        # PDF2Model scripts live under dataflow/cli_funcs/pdf2model_pipeline/
pdf2model_path = dataflow_path / "cli_funcs" / "pdf2model_pipeline" / script_name
if pdf2model_path.exists():
return pdf2model_path
        # Check other possible locations
possible_dirs = [
dataflow_path / "templates" / "pdf2model_pipeline",
dataflow_path / "pipeline_templates"
]
for dir_path in possible_dirs:
script_path = dir_path / script_name
if script_path.exists():
return script_path
return None
    except Exception:
        return None
def copy_customizable_scripts():
"""Only copy scripts that users might want to customize"""
print("Step 0: Copying customizable pipeline script...")
current_dir = Path(os.getcwd())
try:
        # Only copy the scripts that users may need to customize
        scripts_to_copy = [
            "pdf_to_qa_pipeline.py"  # users may need to adjust the vLLM/sglang configuration
]
import shutil
copied_files = []
for script_name in scripts_to_copy:
source_path = get_dataflow_script_path(script_name)
if source_path is None:
print(f"Warning: Template not found: {script_name}")
continue
target_file = current_dir / script_name
shutil.copy2(source_path, target_file)
copied_files.append(script_name)
print(f"Copied: {script_name}")
if copied_files:
print(f"Successfully copied {len(copied_files)} customizable script(s)")
print("You can now modify these files (e.g., switch vLLM/sglang in pdf_to_qa_pipeline.py)")
return True
else:
print("No customizable scripts were copied")
return False
except Exception as e:
print(f"Failed to copy scripts: {e}")
return False
def create_train_config_yaml(cache_path="./", model_name_or_path="Qwen/Qwen2.5-7B-Instruct"):
"""Create train_config.yaml file using built-in LlamaFactory configuration"""
cache_path_obj = Path(cache_path)
if not cache_path_obj.is_absolute():
caller_cwd = Path(os.environ.get('PWD', os.getcwd()))
cache_path_obj = caller_cwd / cache_path_obj
    # Generate a timestamp
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
model_dir_name = f"pdf2model_cache_{timestamp}" # 改为pdf2model_cache前缀
cache_dir = cache_path_obj / ".cache"
cache_dir.mkdir(parents=True, exist_ok=True)
config_file = cache_dir / "train_config.yaml"
try:
        # Use the built-in llama_factory_trainer.py to obtain the default config
llamafactory_script_path = get_dataflow_script_path("llama_factory_trainer.py")
if llamafactory_script_path is None:
print("Built-in llama_factory_trainer.py not found")
return None
import importlib.util
spec = importlib.util.spec_from_file_location("llamafactory_trainer", llamafactory_script_path)
llamafactory_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(llamafactory_module)
        # Create a trainer instance and fetch its default config
trainer = llamafactory_module.LlamaFactoryTrainer(str(config_file), str(cache_path_obj))
config = trainer.get_default_config()
        # Update only the required dynamic parameters
config["model_name_or_path"] = model_name_or_path
config["output_dir"] = str(cache_path_obj / ".cache" / "saves" / model_dir_name)
config["dataset_dir"] = str(cache_path_obj / ".cache" / "data")
        # Choose the chat template based on the model family
if "qwen" in model_name_or_path.lower():
config["template"] = "qwen"
elif "llama" in model_name_or_path.lower():
config["template"] = "llama3"
elif "chatglm" in model_name_or_path.lower():
config["template"] = "chatglm3"
elif "baichuan" in model_name_or_path.lower():
config["template"] = "baichuan2"
        # Save the config
with open(config_file, 'w', encoding='utf-8') as f:
yaml.dump(config, f,
default_flow_style=False,
allow_unicode=True,
sort_keys=False,
indent=2)
print(f"train_config.yaml created: {config_file}")
print(f"Model will be saved to: {model_dir_name}")
return str(config_file)
except Exception as e:
print(f"Failed to create train_config.yaml: {e}")
return None
def verify_environment():
"""Verify runtime environment"""
print("Checking environment...")
missing_deps = []
try:
import llamafactory
print("✅ LlamaFactory installed")
except ImportError:
missing_deps.append("llamafactory[torch,metrics]")
try:
import yaml
print("✅ PyYAML installed")
except ImportError:
missing_deps.append("pyyaml")
if missing_deps:
print(f"❌ Missing dependencies: {', '.join(missing_deps)}")
print(f"Install with: pip install {' '.join(missing_deps)}")
return False
return True
def check_required_files():
"""Check if required built-in scripts exist"""
    # Check all required built-in scripts
required_scripts = [
"path_to_jsonl_script.py",
"llama_factory_trainer.py"
]
missing_scripts = []
for script in required_scripts:
script_path = get_dataflow_script_path(script)
if script_path is None:
missing_scripts.append(script)
else:
print(f"✅ Found built-in script: {script}")
if missing_scripts:
print(f"❌ Missing built-in scripts: {', '.join(missing_scripts)}")
print("These should be part of the dataflow installation")
return False
    # Check the working directory for the customizable script
current_dir = Path(os.getcwd())
customizable_script = current_dir / "pdf_to_qa_pipeline.py"
if customizable_script.exists():
print("✅ Found customizable script: pdf_to_qa_pipeline.py")
else:
print("❌ Missing customizable script: pdf_to_qa_pipeline.py")
print("Run 'dataflow pdf2model init' first")
return False
return True
def cli_pdf2model_init(cache_path: str = "./", model_name: str = "Qwen/Qwen2.5-7B-Instruct") -> bool:
"""
PDF2Model initialization:
0. Copy only customizable scripts to current directory
1. Create train_config.yaml in .cache directory
"""
print("Starting PDF2Model initialization...")
print(f"Cache directory: {cache_path}")
print(f"Model: {model_name}")
print(f"Output directory: pdf2model_cache_<timestamp>") # 更新输出目录显示
print("-" * 60)
if not verify_environment():
return False
try:
# Step 0: Copy only customizable scripts
if not copy_customizable_scripts():
return False
# Step 1: Create training configuration
print("Step 1: Creating training configuration...")
config_file = create_train_config_yaml(cache_path, model_name)
if config_file:
print("PDF2Model initialization completed!")
return True
else:
print("Failed to create training configuration")
return False
except Exception as e:
print(f"Initialization failed: {e}")
return False
def get_latest_model_dir(cache_path_obj):
"""获取最新的模型目录(基于时间戳)"""
saves_dir = cache_path_obj / ".cache" / "saves"
if not saves_dir.exists():
return None
    # Find all directories whose names start with pdf2model_cache_
model_dirs = []
for dir_path in saves_dir.iterdir():
if dir_path.is_dir() and dir_path.name.startswith('pdf2model_cache_'):
            # Verify the timestamp format (YYYYMMDD_HHMMSS)
timestamp_part = dir_path.name.replace('pdf2model_cache_', '')
if len(timestamp_part) == 15 and timestamp_part[8] == '_':
date_part = timestamp_part[:8]
time_part = timestamp_part[9:]
if date_part.isdigit() and time_part.isdigit() and len(time_part) == 6:
model_dirs.append(dir_path)
if not model_dirs:
return None
    # Sort by name (the timestamps sort naturally)
model_dirs.sort(key=lambda x: x.name, reverse=True)
return model_dirs[0]
def cli_pdf2model_train(lf_yaml: str = ".cache/train_config.yaml", cache_path: str = "./") -> bool:
"""
Start PDF2Model training using mix of built-in and user scripts
"""
print("Starting PDF2Model training...")
current_dir = Path(os.getcwd())
cache_path_obj = Path(cache_path)
if not cache_path_obj.is_absolute():
cache_path_obj = current_dir / cache_path_obj
config_path_obj = Path(lf_yaml)
if not config_path_obj.is_absolute():
config_path_obj = current_dir / config_path_obj
if not verify_environment():
return False
if not check_required_files():
return False
if not config_path_obj.exists():
print(f"Training config file not found: {config_path_obj}")
print(f"{Style.BRIGHT}Run 'dataflow pdf2model init' first")
return False
print("-" * 60)
try:
# Step 1: PDF Detection
script1_path = get_dataflow_script_path("path_to_jsonl_script.py")
args1 = ["./", "--output", str(cache_path_obj / ".cache" / "gpu" / "pdf_list.jsonl")]
if not run_script_with_args(script1_path, "Step 1: PDF Detection", args1, cwd=str(current_dir)):
return False
# Step 2: Data Processing
script2 = current_dir / "pdf_to_qa_pipeline.py"
args2 = ["--cache", cache_path]
if not run_script_with_args(script2, "Step 2: Data Processing", args2, cwd=str(current_dir)):
return False
# Step 2.5: Create dataset_info.json (dynamically)
print(f"\n{Fore.BLUE}Step 2.5: Creating dataset_info.json{Style.RESET_ALL}")
        # Read the training config to get the dataset name
try:
with open(config_path_obj, 'r', encoding='utf-8') as f:
train_config = yaml.safe_load(f)
            # Dataset name
dataset_name = train_config.get('dataset')
if isinstance(dataset_name, list):
                dataset_name = dataset_name[0]  # if it's a list, take the first entry
if not dataset_name:
print("Warning: No dataset name found in train_config.yaml, using default 'kb_qa'")
dataset_name = 'kb_qa'
print(f"Dataset name from config: {dataset_name}")
except Exception as e:
print(f"Warning: Could not read train_config.yaml: {e}")
print("Using default dataset name: kb_qa")
dataset_name = 'kb_qa'
        # Create dataset_info.json
dataset_info_path = cache_path_obj / ".cache" / "data" / "dataset_info.json"
dataset_info_path.parent.mkdir(parents=True, exist_ok=True)
dataset_info = {
            dataset_name: {  # ← the name read from the config
"file_name": "qa.json",
"formatting": "alpaca",
"columns": {
"prompt": "instruction",
"query": "input",
"response": "output"
}
}
}
with open(dataset_info_path, 'w', encoding='utf-8') as f:
json.dump(dataset_info, f, indent=2, ensure_ascii=False)
print(f"Created: {dataset_info_path}")
print(f"Dataset registered as: {dataset_name}")
print(f"{Fore.GREEN}✅ Step 2.5: Creating dataset_info.json completed{Style.RESET_ALL}")
# Step 3: Data Conversion - skip
print(f"\n{Fore.BLUE}Step 3: Data Conversion{Style.RESET_ALL}")
qa_json_path = cache_path_obj / ".cache" / "data" / "qa.json"
if qa_json_path.exists():
print(f"✅ qa.json already in correct format, skipping conversion")
print(f"{Fore.GREEN}✅ Step 3: Data Conversion completed{Style.RESET_ALL}")
else:
print(f"❌ qa.json not found at {qa_json_path}")
return False
# Step 4: Training
script4_path = get_dataflow_script_path("llama_factory_trainer.py")
args4 = ["--config", str(config_path_obj), "--cache", cache_path]
if not run_script_with_args(script4_path, "Step 4: Training", args4, cwd=str(current_dir)):
return False
# Show completion info
try:
with open(config_path_obj, 'r', encoding='utf-8') as f:
config = yaml.safe_load(f)
actual_output_dir = config.get('output_dir', 'unknown')
        except Exception:
            actual_output_dir = 'unknown'
print("Training completed successfully!")
print(f"Model saved to: {actual_output_dir}")
print("Next steps:")
print(f"{Style.BRIGHT}Test the trained model with 'dataflow chat'")
return True
except Exception as e:
print(f"Training error: {e}")
return False
def cli_pdf2model_chat(model_path=None, cache_path="./", base_model=None):
"""Start LlamaFactory chat interface"""
current_dir = Path(os.getcwd())
    # Resolve the cache path
cache_path_obj = Path(cache_path)
if not cache_path_obj.is_absolute():
cache_path_obj = current_dir / cache_path_obj
    # Resolve the model path
    if model_path is None:
        # Use the newest trained model directory
latest_model_dir = get_latest_model_dir(cache_path_obj)
if latest_model_dir:
model_path = latest_model_dir
else:
print("No trained model found")
print("Run 'dataflow pdf2model train' to train a model first")
return False
model_path = Path(model_path)
if not model_path.exists():
print(f"Model not found: {model_path}")
print("Run 'dataflow pdf2model train' to train a model first")
return False
    # Verify this is a valid adapter directory
adapter_files = [
"adapter_config.json",
"adapter_model.bin",
"adapter_model.safetensors"
]
has_adapter = any((model_path / f).exists() for f in adapter_files)
if not has_adapter:
print(f"No adapter files found in {model_path}")
print("This doesn't appear to be a trained adapter directory.")
print("Expected files: adapter_config.json, adapter_model.bin/safetensors")
return False
    # Resolve the base model path, reading it defensively
    if base_model is None:
        # Try to read the base model from the training config
config_file = cache_path_obj / ".cache" / "train_config.yaml"
if config_file.exists():
try:
with open(config_file, 'r', encoding='utf-8') as f:
config = yaml.safe_load(f)
base_model = config.get('model_name_or_path')
if base_model:
print(f"Found base model in config: {base_model}")
except Exception as e:
print(f"Warning: Could not read config file: {e}")
    # Try adapter_config.json next
if not base_model:
adapter_config_path = model_path / "adapter_config.json"
if adapter_config_path.exists():
try:
with open(adapter_config_path, 'r', encoding='utf-8') as f:
adapter_config = json.load(f)
base_model = adapter_config.get('base_model_name_or_path')
if base_model:
print(f"Found base model in adapter config: {base_model}")
except Exception as e:
print(f"Warning: Could not read adapter config: {e}")
    # Still no base_model: fail instead of falling back to a default
if not base_model:
print("Cannot determine base model path")
print("Please ensure your training config contains 'model_name_or_path'")
print("Or check that adapter_config.json exists and contains 'base_model_name_or_path'")
return False
    # Check for LlamaFactory
try:
import llamafactory
print("LlamaFactory available")
except ImportError:
print("LlamaFactory not installed")
print("Install with: pip install llamafactory[torch,metrics]")
return False
    # Launch chat directly with command-line arguments
chat_cmd = [
"llamafactory-cli", "chat",
"--model_name_or_path", base_model,
"--adapter_name_or_path", str(model_path.absolute())
]
print(f"Base model: {base_model}")
print(f"Adapter path: {model_path}")
print(f"Command: {' '.join(chat_cmd)}")
print("-" * 60)
print("Starting chat session...")
print("-" * 60)
try:
result = subprocess.run(chat_cmd, check=True)
print("\nChat session completed")
return True
except subprocess.CalledProcessError as e:
print(f"\nChat failed: {e}")
return False
except KeyboardInterrupt:
print("\n\nChat session ended by user")
return True
#!/usr/bin/env python3
"""
DataFlow Text Processing CLI Module - dataflow/cli_funcs/cli_text.py
Text data processing pipeline with complete workflow including Text2QA
"""
import subprocess
import sys
import json
import os
import datetime
from pathlib import Path
from typing import List, Union, Any
from colorama import Fore, Style
from dataflow import get_logger
from .paths import DataFlowPath
logger = get_logger()
def run_script_with_args(script_path: Path, description: str, args: list = None, cwd: str = None) -> bool:
"""Run a Python script with arguments and real-time output"""
print(f"\n{Fore.BLUE}{description}{Style.RESET_ALL}")
cmd = [sys.executable, str(script_path)]
if args:
cmd.extend(args)
print(f"Running: {' '.join(cmd)}")
if cwd:
print(f"Working directory: {cwd}")
try:
subprocess.run(cmd, cwd=cwd, check=True,
stdout=sys.stdout, stderr=sys.stderr, text=True)
print(f"{Fore.GREEN}{description} completed{Style.RESET_ALL}")
return True
except subprocess.CalledProcessError as e:
print(f"{Fore.RED}{description} failed (exit code {e.returncode}){Style.RESET_ALL}")
return False
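# Example usage sketch (the script name and arguments are illustrative, not a
# fixed API):
#   ok = run_script_with_args(Path("./text_to_qa_pipeline.py"),
#                             "Text2QA generation",
#                             args=["--cache", "./"],
#                             cwd=os.getcwd())
#   # ok is True when the script exits with status 0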
def get_dataflow_script_path(script_name: str) -> Path:
"""Get the path of dataflow built-in scripts"""
try:
import dataflow
dataflow_path = Path(dataflow.__file__).parent
# Text2Model scripts live under dataflow/cli_funcs/text2model_pipeline/
text2model_path = dataflow_path / "cli_funcs" / "text2model_pipeline" / script_name
if text2model_path.exists():
return text2model_path
# Check other possible template locations
possible_dirs = [
dataflow_path / "templates" / "text2model_pipeline",
dataflow_path / "pipeline_templates"
]
for dir_path in possible_dirs:
script_path = dir_path / script_name
if script_path.exists():
return script_path
return None
except Exception:
return None
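# Resolution order sketch for get_dataflow_script_path("merge_json_jsonl.py"):
#   1. <dataflow>/cli_funcs/text2model_pipeline/merge_json_jsonl.py
#   2. <dataflow>/templates/text2model_pipeline/merge_json_jsonl.py
#   3. <dataflow>/pipeline_templates/merge_json_jsonl.py
# Returns None when the script is not found in any of these locations.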
def copy_customizable_scripts():
"""Copy scripts that users might want to customize"""
print("Step 0: Setting up customizable pipeline scripts...")
current_dir = Path(os.getcwd())
# Check whether the required script files already exist in the current directory
required_scripts = [
"text_to_qa_pipeline.py",
]
existing_scripts = []
missing_scripts = []
for script_name in required_scripts:
script_path = current_dir / script_name
if script_path.exists():
existing_scripts.append(script_name)
print(f"Found existing: {script_name}")
else:
missing_scripts.append(script_name)
# Try to copy any missing scripts from the built-in templates
import shutil
copied_files = []
for script_name in missing_scripts:
source_path = get_dataflow_script_path(script_name)
if source_path is not None:
try:
target_file = current_dir / script_name
shutil.copy2(source_path, target_file)
copied_files.append(script_name)
print(f"Copied from template: {script_name}")
except Exception as e:
print(f"Warning: Failed to copy {script_name}: {e}")
else:
print(f"Warning: Template not found for {script_name}")
total_available = len(existing_scripts) + len(copied_files)
if total_available > 0:
print(f"Setup completed: {total_available} scripts available")
if existing_scripts:
print(f" Existing scripts: {', '.join(existing_scripts)}")
if copied_files:
print(f" Copied from templates: {', '.join(copied_files)}")
return True
else:
print("Warning: No pipeline scripts available")
return False
def create_train_config_yaml(cache_path="./", model_name_or_path="Qwen/Qwen2.5-7B-Instruct"):
"""Create train_config.yaml file using built-in LlamaFactory configuration"""
cache_path_obj = Path(cache_path)
if not cache_path_obj.is_absolute():
caller_cwd = Path(os.environ.get('PWD', os.getcwd()))
cache_path_obj = caller_cwd / cache_path_obj
# Generate a timestamp for the model directory name
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
model_dir_name = f"text2model_cache_{timestamp}"
cache_dir = cache_path_obj / ".cache"
cache_dir.mkdir(parents=True, exist_ok=True)
config_file = cache_dir / "train_config.yaml"
try:
# Use the built-in llama_factory_trainer.py to obtain the default configuration
llamafactory_script_path = get_dataflow_script_path("llama_factory_trainer.py")
if llamafactory_script_path is None:
print("Built-in llama_factory_trainer.py not found")
return None
import importlib.util
spec = importlib.util.spec_from_file_location("llamafactory_trainer", llamafactory_script_path)
llamafactory_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(llamafactory_module)
# Create a trainer instance and fetch its default configuration
trainer = llamafactory_module.LlamaFactoryTrainer(str(config_file), str(cache_path_obj))
config = trainer.get_default_config()
# Override only the run-specific parameters
config["model_name_or_path"] = model_name_or_path
config["output_dir"] = str(cache_path_obj / ".cache" / "saves" / model_dir_name)
config["dataset_dir"] = str(cache_path_obj / ".cache" / "data")
# Pick the chat template based on the base model family
if "qwen" in model_name_or_path.lower():
config["template"] = "qwen"
elif "llama" in model_name_or_path.lower():
config["template"] = "llama3"
elif "chatglm" in model_name_or_path.lower():
config["template"] = "chatglm3"
elif "baichuan" in model_name_or_path.lower():
config["template"] = "baichuan2"
# Save the configuration
import yaml
with open(config_file, 'w', encoding='utf-8') as f:
yaml.dump(config, f,
default_flow_style=False,
allow_unicode=True,
sort_keys=False,
indent=2)
print(f"train_config.yaml created: {config_file}")
print(f"Model will be saved to: {model_dir_name}")
return str(config_file)
except Exception as e:
print(f"Failed to create train_config.yaml: {e}")
return None
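# Sketch of the generated .cache/train_config.yaml (values are illustrative;
# the remaining keys come from LlamaFactoryTrainer.get_default_config()):
#   model_name_or_path: Qwen/Qwen2.5-7B-Instruct
#   template: qwen
#   output_dir: <cache>/.cache/saves/text2model_cache_20240101_120000
#   dataset_dir: <cache>/.cache/data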
def verify_environment():
"""Verify runtime environment"""
print("Checking environment...")
missing_deps = []
try:
import llamafactory
print("✅ LlamaFactory installed")
except ImportError:
missing_deps.append("llamafactory[torch,metrics]")
try:
import yaml
print("✅ PyYAML installed")
except ImportError:
missing_deps.append("pyyaml")
try:
from dataflow.utils.storage import FileStorage
print("✅ DataFlow storage available")
except ImportError:
missing_deps.append("dataflow")
try:
# Note: import the operators via their current path and class names, aliased to the legacy names
from dataflow.operators.knowledge_cleaning import (
KBCChunkGeneratorBatch as CorpusTextSplitterBatch,
KBCTextCleanerBatch as KnowledgeCleanerBatch,
KBCMultiHopQAGeneratorBatch as MultiHopQAGeneratorBatch
)
print("✅ DataFlow operators available")
except ImportError:
missing_deps.append("dataflow operators")
if missing_deps:
print(f"❌ Missing dependencies: {', '.join(missing_deps)}")
print(f"Install with: pip install {' '.join(missing_deps)}")
return False
return True
def check_required_files_for_training():
"""Check if required built-in scripts exist for training"""
# Check that all required built-in scripts are present
required_scripts = [
"merge_json_jsonl.py",
"llama_factory_trainer.py"
]
missing_scripts = []
for script in required_scripts:
script_path = get_dataflow_script_path(script)
if script_path is None:
missing_scripts.append(script)
else:
print(f"✅ Found built-in script: {script}")
if missing_scripts:
print(f"❌ Missing built-in scripts: {', '.join(missing_scripts)}")
print("These should be part of the dataflow installation")
return False
# Check the user directory for the customizable scripts
current_dir = Path(os.getcwd())
customizable_scripts = [
"text_to_qa_pipeline.py",
]
missing_customizable = []
for script_name in customizable_scripts:
script_path = current_dir / script_name
if script_path.exists():
print(f"✅ Found customizable script: {script_name}")
else:
missing_customizable.append(script_name)
if missing_customizable:
print(f"❌ Missing customizable scripts: {', '.join(missing_customizable)}")
print("Run 'dataflow text2model init' first")
return False
return True
def analyze_input_data(input_file: str) -> dict:
"""分析输入数据的字段结构"""
if not input_file or not Path(input_file).exists():
return {}
try:
with open(input_file, 'r', encoding='utf-8') as f:
first_line = f.readline().strip()
if first_line:
sample_data = json.loads(first_line)
return {
'available_keys': list(sample_data.keys()),
'has_sft_format': all(key in sample_data for key in ['instruction', 'input', 'output']),
'has_text_field': 'text' in sample_data,
'has_raw_content': 'raw_content' in sample_data
}
except Exception as e:
print(f"Could not analyze input file: {e}")
return {}
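# Example (illustrative): if the first JSONL line is
#   {"instruction": "...", "input": "", "output": "..."}
# the function returns
#   {'available_keys': ['instruction', 'input', 'output'],
#    'has_sft_format': True, 'has_text_field': False, 'has_raw_content': False}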
def get_latest_model_dir(cache_path_obj):
"""获取最新的模型目录(基于时间戳)"""
saves_dir = cache_path_obj / ".cache" / "saves"
if not saves_dir.exists():
return None
# Collect all directories whose names start with text2model_cache_
model_dirs = []
for dir_path in saves_dir.iterdir():
if dir_path.is_dir() and dir_path.name.startswith('text2model_cache_'):
# Validate the timestamp suffix format (YYYYMMDD_HHMMSS)
timestamp_part = dir_path.name.replace('text2model_cache_', '')
if len(timestamp_part) == 15 and timestamp_part[8] == '_':
date_part = timestamp_part[:8]
time_part = timestamp_part[9:]
if date_part.isdigit() and time_part.isdigit() and len(time_part) == 6:
model_dirs.append(dir_path)
if not model_dirs:
return None
# Sort by name; the timestamp format sorts chronologically
model_dirs.sort(key=lambda x: x.name, reverse=True)
return model_dirs[0]
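# Example (illustrative): given saves/ directories
#   text2model_cache_20240101_120000 and text2model_cache_20240315_093000,
# the names sort lexicographically, which matches chronological order, so the
# 20240315 directory is returned.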
def cli_text2model_init(cache_path: str = "./") -> bool:
"""
Text2Model initialization:
0. Check for existing scripts and copy any missing templates
1. Create train_config.yaml in .cache directory
"""
if not verify_environment():
return False
try:
# Step 0: Check for existing scripts and setup missing ones
if not copy_customizable_scripts():
print("Warning: Some scripts may be missing, but continuing...")
# Step 1: Create training configuration
print("Step 1: Creating training configuration...")
config_file = create_train_config_yaml(cache_path, "Qwen/Qwen2.5-7B-Instruct")
if config_file:
print("Text2Model initialization completed!")
print("\nWorkflow:")
print("1. Put your JSON/JSONL files with 'text' field in current directory")
print("2. Run: dataflow text2model train")
print(" This will automatically run Text2QA generation and training")
return True
else:
print("Failed to create training configuration")
return False
except Exception as e:
print(f"Initialization failed: {e}")
return False
def cli_text2model_train(input_keys: str = None, lf_yaml: str = "./.cache/train_config.yaml") -> bool:
"""
Start Text2Model training using complete pipeline
"""
print("Starting Text2Model training...")
if input_keys:
print(f"Processing fields: {input_keys}")
current_dir = Path(os.getcwd())
config_path_obj = Path(lf_yaml)
if not config_path_obj.is_absolute():
config_path_obj = current_dir / config_path_obj
if not verify_environment():
return False
if not config_path_obj.exists():
print(f"Training config file not found: {config_path_obj}")
print("Run 'dataflow text2model init' first")
return False
input_dir = "./"
cache_path_obj = current_dir
input_path = Path(input_dir)
if not input_path.is_absolute():
input_path = current_dir / input_path
if not input_path.exists():
print(f"Input directory not found: {input_path}")
return False
print("-" * 60)
try:
# Step 1: Merge JSON/JSONL files to create text_input.jsonl
print(f"{Fore.CYAN}Step 1: Merging JSON/JSONL files...{Style.RESET_ALL}")
# Delegate to the built-in merge_json_jsonl.py script
script1_path = get_dataflow_script_path("merge_json_jsonl.py")
args1 = [str(input_path), "--cache", str(cache_path_obj)]
if not run_script_with_args(script1_path, "JSON/JSONL merging", args1, cwd=str(current_dir)):
print(f"{Fore.RED}❌ Step 1: JSON/JSONL merging failed{Style.RESET_ALL}")
return False
# Verify that text_input.jsonl was created
text_input_file = cache_path_obj / ".cache" / "gpu" / "text_input.jsonl"
if not text_input_file.exists():
print(
f"{Fore.RED}❌ text_input.jsonl not created. Check if you have JSON/JSONL files in {input_path}{Style.RESET_ALL}")
return False
file_size = text_input_file.stat().st_size
print(f"{Fore.GREEN}✅ Step 1 completed: {text_input_file} ({file_size} bytes){Style.RESET_ALL}")
# Step 2: Text2QA Pipeline
print(f"{Fore.CYAN}Step 2: Text2QA generation...{Style.RESET_ALL}")
script2_path = cache_path_obj / "text_to_qa_pipeline.py"
args2 = ["--cache", str(cache_path_obj)]
if not run_script_with_args(script2_path, "Text2QA generation", args2, cwd=str(current_dir)):
print(f"{Fore.RED}❌ Step 2: Text2QA generation failed{Style.RESET_ALL}")
return False
# Verify the Text2QA output
qa_output_file = cache_path_obj / ".cache" / "gpu" / "text2qa_step_step3.json"
if not qa_output_file.exists():
print(f"{Fore.RED}❌ Text2QA output not found{Style.RESET_ALL}")
return False
file_size = qa_output_file.stat().st_size
print(f"{Fore.GREEN}✅ Step 2 completed: {qa_output_file} ({file_size} bytes){Style.RESET_ALL}")
# Step 3: Convert to training format
print(f"{Fore.CYAN}Step 3: Converting to training format...{Style.RESET_ALL}")
script3_path = get_dataflow_script_path("merge_filter_qa_pairs.py")
args3 = ["--cache", str(cache_path_obj)]
if not run_script_with_args(script3_path, "QA format conversion", args3, cwd=str(current_dir)):
print(f"{Fore.RED}❌ Step 3: QA format conversion failed{Style.RESET_ALL}")
return False
# Verify that the training data files were created
qa_file = cache_path_obj / ".cache" / "data" / "qa.json"
dataset_info_file = cache_path_obj / ".cache" / "data" / "dataset_info.json"
if not qa_file.exists() or not dataset_info_file.exists():
print(f"{Fore.RED}❌ Training data files not created{Style.RESET_ALL}")
return False
# Count the training samples
try:
with open(qa_file, 'r', encoding='utf-8') as f:
qa_data = json.load(f)
sample_count = len(qa_data)
file_size = qa_file.stat().st_size
print(
f"{Fore.GREEN}✅ Step 3 completed: {sample_count} training samples ({file_size} bytes){Style.RESET_ALL}")
except Exception:
print(f"{Fore.GREEN}✅ Step 3 completed{Style.RESET_ALL}")
# Step 4: Training
print(f"{Fore.CYAN}Step 4: Starting model training...{Style.RESET_ALL}")
script4_path = get_dataflow_script_path("llama_factory_trainer.py")
args4 = ["--config", str(config_path_obj), "--cache", str(cache_path_obj)]
if not run_script_with_args(script4_path, "Model training", args4, cwd=str(current_dir)):
print(f"{Fore.RED}❌ Step 4: Training failed{Style.RESET_ALL}")
return False
print(f"{Fore.GREEN}✅ Text2Model training completed successfully!{Style.RESET_ALL}")
print(f"Next steps:")
print(f" Test the model: dataflow chat")
return True
except Exception as e:
print(f"Training error: {e}")
return False
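# Cache layout after a successful run (paths relative to the cache directory;
# the annotations summarize the steps above):
#   .cache/gpu/text_input.jsonl                  merged raw text (Step 1)
#   .cache/gpu/text2qa_step_step3.json           generated QA pairs (Step 2)
#   .cache/data/qa.json                          training samples (Step 3)
#   .cache/data/dataset_info.json                dataset registry for training (Step 3)
#   .cache/saves/text2model_cache_<timestamp>/   trained adapter (Step 4)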
def _run_text2qa_workflow(current_dir: Path, cache_path_obj: Path, config_path_obj: Path) -> bool:
"""Run Text2QA workflow"""
# Step 1: Check if Text2QA output exists
text2qa_output = cache_path_obj / ".cache" / "gpu" / "text2qa_step_step3.json"
if not text2qa_output.exists():
print("Text2QA output not found. Please run text_to_qa_pipeline.py first.")
print("Example:")
print(" 1. Prepare JSON/JSONL files with 'text' field in current directory")
print(" 2. Run: python text_to_qa_pipeline.py")
print(" 3. Then run: dataflow text2model train --text2qa")
return False
print("Found Text2QA output, proceeding with conversion...")
# Step 2: Convert QA to Alpaca format
script2 = current_dir / "merge_filter_qa_pairs.py"
args2 = ["--cache", str(cache_path_obj)]
if not run_script_with_args(script2, "Step 2: Converting QA to Alpaca format", args2, cwd=str(current_dir)):
return False
# Step 3: Training
script3_path = get_dataflow_script_path("llama_factory_trainer.py")
args3 = ["--config", str(config_path_obj)]
if not run_script_with_args(script3_path, "Step 3: Training", args3, cwd=str(current_dir)):
return False
# Report where the trained model was saved
try:
import yaml
with open(config_path_obj, 'r', encoding='utf-8') as f:
config = yaml.safe_load(f)
actual_output_dir = config.get('output_dir', 'unknown')
except Exception:
actual_output_dir = 'unknown'
print("Text2QA training completed successfully!")
print(f"Model saved to: {actual_output_dir}")
print("Next steps:")
print("Test the trained model with 'dataflow chat'")
return True
# def _run_normal_text_workflow(input_path: Path, current_dir: Path, cache_path_obj: Path, config_path_obj: Path,
# input_keys: str) -> bool:
# """Run Text2QA workflow as the main processing pipeline"""
# # Step 1: Merge JSON/JSONL files and create text_input.jsonl - uses the built-in script
# script1_path = get_dataflow_script_path("merge_json_jsonl.py")
# args1 = [str(input_path), "--cache", str(cache_path_obj)]
# if not run_script_with_args(script1_path, "Step 1: Preparing text input for Text2QA", args1, cwd=str(current_dir)):
# return False
# # Step 2: Run Text2QA Pipeline - uses the script in the user directory
# script2 = current_dir / "text_to_qa_pipeline.py"
# args2 = ["--cache", str(cache_path_obj)]
# if not run_script_with_args(script2, "Step 2: Text2QA generation", args2, cwd=str(current_dir)):
# return False
# # Step 3: Convert QA to Alpaca format - uses the script in the user directory
# script3 = current_dir / "merge_filter_qa_pairs.py"
# args3 = ["--cache", str(cache_path_obj)]
# if not run_script_with_args(script3, "Step 3: Converting QA to training format", args3, cwd=str(current_dir)):
# return False
# # Step 4: Training - uses the built-in script
# script4_path = get_dataflow_script_path("llama_factory_trainer.py")
# args4 = ["--config", str(config_path_obj)]
# if not run_script_with_args(script4_path, "Step 4: Training", args4, cwd=str(current_dir)):
# return False
# # Show training-completion info, reading the actual output directory from the config file
# try:
# import yaml
# with open(config_path_obj, 'r', encoding='utf-8') as f:
# config = yaml.safe_load(f)
# actual_output_dir = config.get('output_dir', 'unknown')
# except:
# actual_output_dir = 'unknown'
# print("Text2QA training completed successfully!")
# print(f"Model saved to: {actual_output_dir}")
# print("Next steps:")
# print("Test the trained model with 'dataflow chat'")
# return True
def cli_text2model_chat(model_path=None):
"""Start LlamaFactory chat interface for text2model"""
current_dir = Path(os.getcwd())
# Use the default cache path
cache_path_obj = current_dir
# Determine the model path
if model_path is None:
# Fall back to the most recently trained model directory
latest_model_dir = get_latest_model_dir(cache_path_obj)
if latest_model_dir:
model_path = latest_model_dir
else:
print("No trained model found")
print("Run 'dataflow text2model train' to train a model first")
return False
model_path = Path(model_path)
if not model_path.exists():
print(f"Model not found: {model_path}")
print("Run 'dataflow text2model train' to train a model first")
return False
# Verify this is a valid adapter directory
adapter_files = [
"adapter_config.json",
"adapter_model.bin",
"adapter_model.safetensors"
]
has_adapter = any((model_path / f).exists() for f in adapter_files)
if not has_adapter:
print(f"No adapter files found in {model_path}")
print("This doesn't appear to be a trained adapter directory.")
print("Expected files: adapter_config.json, adapter_model.bin/safetensors")
return False
# Safely determine the base model
base_model = None
# Try to read the base model from the training config
config_file = cache_path_obj / ".cache" / "train_config.yaml"
if config_file.exists():
try:
import yaml
with open(config_file, 'r', encoding='utf-8') as f:
config = yaml.safe_load(f)
base_model = config.get('model_name_or_path')
if base_model:
print(f"Found base model in config: {base_model}")
except Exception as e:
print(f"Warning: Could not read config file: {e}")
# Fall back to reading adapter_config.json
if not base_model:
adapter_config_path = model_path / "adapter_config.json"
if adapter_config_path.exists():
try:
with open(adapter_config_path, 'r', encoding='utf-8') as f:
adapter_config = json.load(f)
base_model = adapter_config.get('base_model_name_or_path')
if base_model:
print(f"Found base model in adapter config: {base_model}")
except Exception as e:
print(f"Warning: Could not read adapter config: {e}")
# If the base model still cannot be determined, fail instead of guessing
if not base_model:
print("Cannot determine base model path")
print("Please ensure your training config contains 'model_name_or_path'")
print("Or check that adapter_config.json exists and contains 'base_model_name_or_path'")
return False
# Make sure LlamaFactory is installed
try:
import llamafactory
print("LlamaFactory available")
except ImportError:
print("LlamaFactory not installed")
print("Install with: pip install llamafactory[torch,metrics]")
return False
# Launch the chat session directly via command-line arguments
chat_cmd = [
"llamafactory-cli", "chat",
"--model_name_or_path", base_model,
"--adapter_name_or_path", str(model_path.absolute())
]
print(f"Base model: {base_model}")
print(f"Adapter path: {model_path}")
print("-" * 60)
print("Starting chat session...")
print("-" * 60)
try:
subprocess.run(chat_cmd, check=True)
print("\nChat session completed")
return True
except subprocess.CalledProcessError as e:
print(f"\nChat failed: {e}")
return False
except KeyboardInterrupt:
print("\n\nChat session ended by user")
return True