Commit 97e8278b authored by zzg_666

Adapt the backend to vLLM

Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
# Explicitly include source files
recursive-include dataflow *.py
recursive-include dataflow *.json
recursive-include dataflow *.yaml
recursive-include dataflow *.sh
recursive-include dataflow *.txt
# Include all file types under the example folder
# The example folder lives at dataflow/example
recursive-include dataflow/example */
# Include top-level documentation files
include README.md
include requirements.txt
include LICENSE
# Exclude cache artifacts
global-exclude *.py[cod]
global-exclude __pycache__/
global-exclude *.so
global-exclude *.egg-info/
# DataFlow
DataFlow is a data preparation system that parses, generates, processes, and evaluates high-quality data from noisy sources (PDFs, plain text, low-quality Q&A) to improve the performance of large language models (LLMs) in specific domains. It supports pretraining, supervised fine-tuning (SFT), reinforcement learning training, and RAG systems built on knowledge bases.
## Requirements
| Software | Version |
| :------: | :------: |
| DTK | 25.04.2 |
| python | 3.10.12 |
| transformers | 4.53.3 |
| vllm | 0.9.2+das.opt1.dtk25042 |
| torch | 2.5.1+das.opt1.dtk25042 |
| torchaudio | 2.5.1+das.opt1.dtk25042 |
| torchvision | 0.20.1+das.opt1.dtk25042 |
| flash_mla | 1.0.0+das.opt1.dtk25042 |
## Installation
To run inference on DCU with vLLM as the backend, use the following commands:
```bash
docker run -it --shm-size 60g --network=host --name dataflow --privileged \
    --device=/dev/kfd --device=/dev/dri --device=/dev/mkfd --group-add video \
    --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -u root \
    -v /opt/hyhal/:/opt/hyhal/:ro -v $PWD:/home/ \
    image.sourcefind.cn:5000/dcu/admin/base/vllm:0.9.2-ubuntu22.04-dtk25.04.2-py3.10 bash
git clone https://developer.sourcefind.cn/codes/qteam/dataflow
cd dataflow
pip install -e .[vllm]
```
More container images are available for download from [光源 (SourceFind)](https://sourcefind.cn/#/service-list).
The DCU-specific deep learning libraries required by this project can be downloaded from the [光合](https://developer.sourcefind.cn/tool/) developer community.
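As a quick sanity check of the DCU software stack (a suggested step, not part of the upstream docs), you can confirm that the DAS builds of torch and vllm listed in the table above are the ones actually installed:
```bash
# Both versions should end in +das.opt1.dtk25042, matching the requirements table
python -c "import torch, vllm; print(torch.__version__, vllm.__version__)"
```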
After installation, you can verify that it succeeded with:
```bash
dataflow -v
```
If the installation is working and DataFlow is the latest release, you will see:
```bash
open-dataflow codebase version: 1.0.7
Checking for updates...
Local version : 1.0.7
PyPI version : 1.0.7
You are using the latest version: 1.0.7
```
## References
- https://github.com/OpenDCAI/DataFlow
from .utils import *
from .version import __version__, version_info
from .logger import get_logger
from .operators import *
from .prompts import *
__all__ = [
    '__version__',
    'version_info',
    'get_logger',
]


def hello():
    return "Hello from open-dataflow!"
from .cli_init import cli_init
from .cli_env import cli_env
__all__ = [
    "cli_env",
    "cli_init",
]
import os
import torch
import platform
from colorama import init, Fore, Style
from dataflow import __version__
import importlib.metadata
from .paths import DataFlowPath


def is_torch_cuda_available():
    """
    Check if CUDA is available for PyTorch.
    """
    return torch.cuda.is_available()


def get_env_info():
    info = {
        "`Dataflow` version": __version__,
        "`Dataflow` install path": DataFlowPath.get_dataflow_dir(),
        "Platform": platform.platform(),
        "Python version": platform.python_version(),
        "PyTorch version": torch.__version__,
    }
    try:
        # Report torchvision's own version rather than torch's
        info["Torchvision version"] = importlib.metadata.version("torchvision")
    except Exception:
        pass
    if is_torch_cuda_available():
        info["PyTorch version"] += " (GPU)"
        info["GPU type"] = torch.cuda.get_device_name()
        info["GPU number"] = torch.cuda.device_count()
        # mem_get_info() returns (free, total); report the total device memory
        info["GPU memory"] = f"{torch.cuda.mem_get_info()[1] / (1024**3):.2f}GB"
    try:
        import deepspeed  # type: ignore
        info["DeepSpeed version"] = deepspeed.__version__
    except Exception:
        pass
    try:
        import bitsandbytes  # type: ignore
        info["Bitsandbytes version"] = bitsandbytes.__version__
    except Exception:
        pass
    try:
        import vllm
        info["vLLM version"] = vllm.__version__
    except Exception:
        pass
    try:
        import sglang
        info["SGLang version"] = sglang.__version__
    except Exception:
        pass
    try:
        info["MinerU version"] = importlib.metadata.version("mineru")
    except Exception:
        pass
    try:
        import subprocess
        # Resolve the dataflow package directory and read its git commit hash
        # in a subprocess, without changing the current working directory.
        dataflow_dir = os.path.dirname(os.path.abspath(__file__))
        commit_info = subprocess.run(
            ["git", "rev-parse", "HEAD"],
            capture_output=True, text=True, check=True, cwd=dataflow_dir,
        )
        info["Git commit"] = commit_info.stdout.strip()
    except Exception:
        pass
    print(Fore.BLUE + "=" * os.get_terminal_size().columns + Style.RESET_ALL)
    print("\n" + "\n".join([f"- {key}: {value}" for key, value in info.items()]) + "\n")
    print(Fore.BLUE + "=" * os.get_terminal_size().columns + Style.RESET_ALL)


def cli_env():
    get_env_info()


if __name__ == "__main__":
    cli_env()
# dataflow/cli_funcs/cli_eval.py
"""DataFlow evaluation tool"""
import os
import json
import shutil
import importlib.util
from pathlib import Path
from typing import List, Dict, Any
from datetime import datetime
from dataflow import get_logger
from dataflow.serving import LocalModelLLMServing_vllm
from dataflow.operators.reasoning import ReasoningAnswerGenerator
from dataflow.prompts.reasoning.diy import DiyAnswerGeneratorPrompt
from dataflow.utils.storage import FileStorage
import torch
import gc

logger = get_logger()

DEFAULT_ANSWER_PROMPT = """Please answer the following question based on the provided academic literature. Your response should:
1. Provide accurate information from the source material
2. Include relevant scientific reasoning and methodology
3. Reference specific findings, data, or conclusions when applicable
4. Maintain academic rigor and precision in your explanation
Question: {question}
Answer:"""
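# The pipeline below is driven by a plain config dict. A minimal sketch of its
# expected shape, inferred from the lookups in this file (paths and values are
# illustrative, not defaults shipped with DataFlow):
#
# {
#     "TARGET_MODELS": ["/models/Qwen2.5-7B-Instruct",
#                       {"path": "/models/other", "name": "other"}],
#     "DEFAULT_MODEL_CONFIG": {"tensor_parallel_size": 2, "max_tokens": 1024},
#     "DATA_CONFIG": {"input_file": "./.cache/data/qa.json", "question_key": "input"},
#     "EVALUATOR_RUN_CONFIG": {"input_test_answer_key": "model_generated_answer",
#                              "input_gt_answer_key": "output",
#                              "input_question_key": "input"},
#     "create_judge_serving": ...,  # callable returning the judge LLM serving
#     "create_storage": ...,        # callable(file_path, cache_dir) -> FileStorage
#     "create_evaluator": ...,      # callable(judge_serving, result_file) -> evaluator
# }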
class EvaluationPipeline:
    """Evaluation pipeline"""

    def __init__(self, config: Dict[str, Any]):
        self.config = config
        # self.cli_args = cli_args
        self.prepared_models = []
        self.generated_files = []

    def run(self) -> bool:
        try:
            # 1. Resolve target models
            self.target_models = self._get_target_models()
            if not self.target_models:
                logger.error("No TARGET_MODELS found in config")
                return False
            self.prepared_models = self._prepare_models()
            if not self.prepared_models:
                return False
            # 2. Generate answers
            self.generated_files = self._generate_answers()
            if not self.generated_files:
                return False
            # 3. Run evaluation
            results = self._run_evaluation()
            # 4. Generate report
            self._generate_report(results)
            return True
        except Exception as e:
            logger.error(f"Evaluation failed: {e}")
            import traceback
            traceback.print_exc()
            return False

    def _get_target_models(self) -> List:
        """Return the list of target models from the config."""
        target_config = self.config.get("TARGET_MODELS", [])
        if not isinstance(target_config, list):
            logger.error(f"TARGET_MODELS must be a list, got {type(target_config)}")
            return []
        if not target_config:
            logger.error("TARGET_MODELS is empty")
            return []
        return target_config

    def _prepare_models(self) -> List[Dict]:
        """Normalize model entries into model info dicts."""
        prepared = []
        default_config = self.config.get("DEFAULT_MODEL_CONFIG", {})
        for idx, item in enumerate(self.target_models, 1):
            if isinstance(item, str):
                model_info = {
                    "name": Path(item).name,
                    "path": item,
                    "type": "local",
                    **default_config
                }
            elif isinstance(item, dict):
                if "path" not in item:
                    logger.error(f"Model at index {idx} missing 'path'")
                    continue
                model_info = {
                    **default_config,  # 1. start from the defaults
                    **item,            # 2. let the user config override them
                    "name": item.get("name", Path(item["path"]).name),  # 3. ensure a proper name
                    "type": "local"    # 4. force the local type
                }
            else:
                logger.error(f"Invalid model format at index {idx}")
                continue
            prepared.append(model_info)
        return prepared

    def _clear_vllm_cache(self):
        """Clear the vLLM cache directories."""
        cache_paths = [
            Path.home() / ".cache" / "vllm" / "torch_compile_cache",
            Path.home() / ".cache" / "vllm"
        ]
        for cache_path in cache_paths:
            if cache_path.exists():
                try:
                    shutil.rmtree(cache_path)
                except Exception as e:
                    logger.warning(f"Failed to clear cache: {e}")

    def _generate_answers(self) -> List[Dict]:
        """Generate answers with each target model."""
        generated_files = []
        data_config = self.config.get("DATA_CONFIG", {})
        input_file = data_config.get("input_file", "./.cache/data/qa.json")
        if not Path(input_file).exists():
            logger.error(f"Input file not found: {input_file}")
            return []
        self._clear_vllm_cache()
        for idx, model_info in enumerate(self.prepared_models, 1):
            llm_serving = None
            answer_generator = None
            storage = None
            try:
                logger.info(f"[{idx}/{len(self.prepared_models)}] Processing: {model_info['name']}")
                cache_dir = model_info.get('cache_dir', './.cache/eval')
                Path(cache_dir).mkdir(parents=True, exist_ok=True)
                output_file = f"{cache_dir}/answers_{model_info['name']}.json"
                # Load the model
                llm_serving = LocalModelLLMServing_vllm(
                    hf_model_name_or_path=model_info['path'],
                    vllm_tensor_parallel_size=model_info.get('tensor_parallel_size', 2),
                    vllm_max_tokens=model_info.get('max_tokens', 1024),
                    vllm_gpu_memory_utilization=model_info.get('gpu_memory_utilization', 0.8)
                )
                # Answer generator
                custom_prompt = model_info.get('answer_prompt', DEFAULT_ANSWER_PROMPT)
                answer_generator = ReasoningAnswerGenerator(
                    llm_serving=llm_serving,
                    prompt_template=DiyAnswerGeneratorPrompt(custom_prompt)
                )
                # Storage
                cache_path = f"{cache_dir}/{model_info['name']}_generation"
                storage = FileStorage(
                    first_entry_file_name=input_file,
                    cache_path=cache_path,
                    file_name_prefix=model_info.get('file_prefix', 'answer_gen'),
                    cache_type=model_info.get('cache_type', 'json')
                )
                # Run generation
                answer_generator.run(
                    storage=storage.step(),
                    input_key=data_config.get("question_key", "input"),
                    output_key=model_info.get('output_key', 'model_generated_answer')
                )
                # Persist the results
                file_prefix = model_info.get('file_prefix', 'answer_gen')
                cache_type = model_info.get('cache_type', 'json')
                # Find all matching step files
                pattern = f"{file_prefix}_step*.{cache_type}"
                matching_files = sorted(Path(cache_path).glob(pattern))
                if matching_files:
                    # Use the newest file (last step)
                    gen_file = matching_files[-1]
                    shutil.copy2(gen_file, output_file)
                    generated_files.append({
                        "model_name": model_info['name'],
                        "model_path": model_info['path'],
                        "file_path": output_file
                    })
                else:
                    logger.error(f"No generated file found for {model_info['name']} in {cache_path}")
                    continue
            except Exception as e:
                logger.error(f"Failed to process {model_info['name']}: {e}")
                continue
            finally:
                if answer_generator is not None:
                    del answer_generator
                if storage is not None:
                    del storage
                if llm_serving is not None:
                    del llm_serving
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                    torch.cuda.synchronize()
        return generated_files

    def _run_evaluation(self) -> List[Dict]:
        """Run the judge-based evaluation."""
        try:
            judge_serving = self.config["create_judge_serving"]()
        except Exception as e:
            logger.error(f"Failed to create judge: {e}")
            return []
        results = []
        eval_config = self.config.get("EVALUATOR_RUN_CONFIG", {})
        for file_info in self.generated_files:
            try:
                timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                result_file = f"./eval_results/{timestamp}_{file_info['model_name']}/result.json"
                Path(result_file).parent.mkdir(parents=True, exist_ok=True)
                storage = self.config["create_storage"](
                    file_info["file_path"],
                    f"./.cache/eval/{file_info['model_name']}"
                )
                evaluator = self.config["create_evaluator"](judge_serving, result_file)
                evaluator.run(
                    storage=storage.step(),
                    input_test_answer_key=eval_config.get("input_test_answer_key", "model_generated_answer"),
                    input_gt_answer_key=eval_config.get("input_gt_answer_key", "output"),
                    input_question_key=eval_config.get("input_question_key", "input")
                )
                if Path(result_file).exists():
                    with open(result_file, 'r') as f:
                        data = json.load(f)
                    if data:
                        data[0]["model_name"] = file_info['model_name']
                        results.append(data[0])
            except Exception as e:
                logger.error(f"Eval failed for {file_info['model_name']}: {e}")
                continue
        return results

    def _generate_report(self, results: List[Dict]):
        """Print and save the evaluation report."""
        if not results:
            logger.warning("No results")
            return
        sorted_results = sorted(results, key=lambda x: x.get("accuracy", 0), reverse=True)
        print("\n" + "=" * 60)
        print("Model Evaluation Results")
        print("=" * 60)
        for i, r in enumerate(sorted_results, 1):
            print(f"{i}. {r['model_name']}")
            print(f"   Accuracy: {r.get('accuracy', 0):.3f}")
            print(f"   Total: {r.get('total_samples', 0)}")
            print(f"   Matched: {r.get('matched_samples', 0)}")
            print()
        print("=" * 60)
        # Save the detailed report
        report_file = "./eval_results/report.json"
        Path(report_file).parent.mkdir(parents=True, exist_ok=True)
        with open(report_file, 'w') as f:
            json.dump({"results": sorted_results}, f, indent=2)
        print(f"Detailed report: {report_file}")


class DataFlowEvalCLI:
    """CLI tool"""

    def __init__(self):
        self.current_dir = Path.cwd()

    def _get_template_path(self, eval_type: str) -> Path:
        current_file = Path(__file__)
        dataflow_dir = current_file.parent.parent
        return dataflow_dir / "cli_funcs" / "eval_pipeline" / f"eval_{eval_type}.py"

    def init_eval_files(self):
        """Initialize the evaluation config files."""
        files = [("eval_api.py", "api"), ("eval_local.py", "local")]
        existing = [f for f, _ in files if (self.current_dir / f).exists()]
        if existing:
            if input(f"{', '.join(existing)} exists. Overwrite? (y/n): ").lower() != 'y':
                return False
        for filename, eval_type in files:
            try:
                template = self._get_template_path(eval_type)
                if not template.exists():
                    logger.error(f"Template not found: {template}")
                    continue
                shutil.copy2(template, self.current_dir / filename)
                logger.info(f"Created: {filename}")
            except Exception as e:
                logger.error(f"Failed: {e}")
        logger.info("You must modify eval_api.py or eval_local.py before running 'dataflow eval api/local'")
        return True

    def run_eval_file(self, eval_file: str):
        """Run an evaluation config file."""
        config_path = self.current_dir / eval_file
        if not config_path.exists():
            logger.error(f"Config not found: {eval_file}")
            return False
        try:
            spec = importlib.util.spec_from_file_location("config", config_path)
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
            config = module.get_evaluator_config()
            return run_evaluation(config)
        except Exception as e:
            logger.error(f"Failed: {e}")
            return False


def run_evaluation(config):
    """Run the evaluation pipeline."""
    pipeline = EvaluationPipeline(config)
    return pipeline.run()
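# Typical flow (a sketch based on run_eval_file above): `dataflow eval api`
# loads eval_api.py from the current directory, calls its
# get_evaluator_config() to obtain the config dict sketched near the top of
# this file, and hands it to run_evaluation().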
import os
import re
from colorama import init, Fore, Style
from .paths import DataFlowPath
from .copy_funcs import copy_files_without_recursion, copy_file, copy_files_recursively


def _copy_scripts():
    target_dir = os.getcwd()
    if not os.path.exists(target_dir):
        os.makedirs(target_dir)
    # script_path = DataFlowPath.get_dataflow_scripts_dir()
    copy_files_recursively(DataFlowPath.get_dataflow_scripts_dir(), target_dir)


def _copy_pipelines():
    target_dir = os.getcwd()
    if not os.path.exists(target_dir):
        os.makedirs(target_dir)
    copy_files_recursively(DataFlowPath.get_dataflow_pipelines_dir(), target_dir)


def _copy_playground():
    target_dir = os.getcwd()
    if not os.path.exists(target_dir):
        os.makedirs(target_dir)
    copy_files_recursively(DataFlowPath.get_dataflow_playground_dir(), target_dir)


def _copy_examples():
    target_dir = os.path.join(os.getcwd(), "example_data")
    if not os.path.exists(target_dir):
        os.makedirs(target_dir)
    copy_files_recursively(DataFlowPath.get_dataflow_example_dir(), target_dir)


def cli_init(subcommand):
    print(f'{Fore.GREEN}Initializing in current working directory...{Style.RESET_ALL}')
    # Base initialization: copy only the default pipelines, examples, and playground
    if subcommand == "base":
        _copy_pipelines()
        _copy_examples()
        _copy_playground()
    # if subcommand == "model_zoo":
    #     _copy_train_scripts()
    #     _copy_demo_runs()
    #     _copy_demo_configs()
    #     _copy_dataset_json()
    # if subcommand == "backbone":
    #     _copy_train_scripts()
    #     _copy_demo_runs()
    #     _copy_demo_configs()
    #     _copy_dataset_json()
    # print(f'{Fore.GREEN}Successfully initialized IMDLBenCo scripts.{Style.RESET_ALL}')
#!/usr/bin/env python3
"""
DataFlow PDF2Model CLI Module - dataflow/cli_funcs/cli_pdf.py
PDF to Model training pipeline with init/train/chat commands
"""
import subprocess
import sys
import yaml
import json
import os
import datetime
from pathlib import Path
from colorama import Fore, Style
from dataflow import get_logger
from .paths import DataFlowPath

logger = get_logger()


def run_script_with_args(script_path: Path, description: str, args: list = None, cwd: str = None) -> bool:
    """Run a Python script with arguments and real-time output"""
    print(f"\n{Fore.BLUE}{description}{Style.RESET_ALL}")
    cmd = [sys.executable, str(script_path)]
    if args:
        cmd.extend(args)
    print(f"Running: {' '.join(cmd)}")
    if cwd:
        print(f"Working directory: {cwd}")
    try:
        subprocess.run(cmd, cwd=cwd, check=True,
                       stdout=sys.stdout, stderr=sys.stderr, text=True)
        print(f"{Fore.GREEN}{description} completed{Style.RESET_ALL}")
        return True
    except subprocess.CalledProcessError:
        print(f"{Fore.RED}{description} failed{Style.RESET_ALL}")
        return False


def get_dataflow_script_path(script_name: str) -> Path:
    """Get the path of a dataflow built-in script"""
    try:
        import dataflow
        dataflow_path = Path(dataflow.__file__).parent
        # PDF2Model scripts live under dataflow/cli_funcs/pdf2model_pipeline/
        pdf2model_path = dataflow_path / "cli_funcs" / "pdf2model_pipeline" / script_name
        if pdf2model_path.exists():
            return pdf2model_path
        # Check other possible locations
        possible_dirs = [
            dataflow_path / "templates" / "pdf2model_pipeline",
            dataflow_path / "pipeline_templates"
        ]
        for dir_path in possible_dirs:
            script_path = dir_path / script_name
            if script_path.exists():
                return script_path
        return None
    except Exception:
        return None


def copy_customizable_scripts():
    """Only copy scripts that users might want to customize"""
    print("Step 0: Copying customizable pipeline script...")
    current_dir = Path(os.getcwd())
    try:
        # Only copy scripts the user may need to customize
        scripts_to_copy = [
            "pdf_to_qa_pipeline.py"  # users may want to tweak the vLLM/sglang setup
        ]
        import shutil
        copied_files = []
        for script_name in scripts_to_copy:
            source_path = get_dataflow_script_path(script_name)
            if source_path is None:
                print(f"Warning: Template not found: {script_name}")
                continue
            target_file = current_dir / script_name
            shutil.copy2(source_path, target_file)
            copied_files.append(script_name)
            print(f"Copied: {script_name}")
        if copied_files:
            print(f"Successfully copied {len(copied_files)} customizable script(s)")
            print("You can now modify these files (e.g., switch vLLM/sglang in pdf_to_qa_pipeline.py)")
            return True
        else:
            print("No customizable scripts were copied")
            return False
    except Exception as e:
        print(f"Failed to copy scripts: {e}")
        return False
def create_train_config_yaml(cache_path="./", model_name_or_path="Qwen/Qwen2.5-7B-Instruct"):
    """Create train_config.yaml using the built-in LlamaFactory configuration"""
    cache_path_obj = Path(cache_path)
    if not cache_path_obj.is_absolute():
        caller_cwd = Path(os.environ.get('PWD', os.getcwd()))
        cache_path_obj = caller_cwd / cache_path_obj
    # Generate a timestamped output directory name with the pdf2model_cache prefix
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    model_dir_name = f"pdf2model_cache_{timestamp}"
    cache_dir = cache_path_obj / ".cache"
    cache_dir.mkdir(parents=True, exist_ok=True)
    config_file = cache_dir / "train_config.yaml"
    try:
        # Load the built-in llama_factory_trainer.py to obtain the default config
        llamafactory_script_path = get_dataflow_script_path("llama_factory_trainer.py")
        if llamafactory_script_path is None:
            print("Built-in llama_factory_trainer.py not found")
            return None
        import importlib.util
        spec = importlib.util.spec_from_file_location("llamafactory_trainer", llamafactory_script_path)
        llamafactory_module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(llamafactory_module)
        # Create a trainer instance and fetch its default config
        trainer = llamafactory_module.LlamaFactoryTrainer(str(config_file), str(cache_path_obj))
        config = trainer.get_default_config()
        # Update only the dynamic parameters
        config["model_name_or_path"] = model_name_or_path
        config["output_dir"] = str(cache_path_obj / ".cache" / "saves" / model_dir_name)
        config["dataset_dir"] = str(cache_path_obj / ".cache" / "data")
        # Pick the chat template based on the model family
        if "qwen" in model_name_or_path.lower():
            config["template"] = "qwen"
        elif "llama" in model_name_or_path.lower():
            config["template"] = "llama3"
        elif "chatglm" in model_name_or_path.lower():
            config["template"] = "chatglm3"
        elif "baichuan" in model_name_or_path.lower():
            config["template"] = "baichuan2"
        # Save the config
        with open(config_file, 'w', encoding='utf-8') as f:
            yaml.dump(config, f,
                      default_flow_style=False,
                      allow_unicode=True,
                      sort_keys=False,
                      indent=2)
        print(f"train_config.yaml created: {config_file}")
        print(f"Model will be saved to: {model_dir_name}")
        return str(config_file)
    except Exception as e:
        print(f"Failed to create train_config.yaml: {e}")
        return None
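# The resulting .cache/train_config.yaml looks roughly like this (illustrative;
# most keys come from LlamaFactoryTrainer.get_default_config(), which is not
# shown here, and only the values set above are certain):
#   model_name_or_path: Qwen/Qwen2.5-7B-Instruct
#   template: qwen
#   dataset: kb_qa                  # assumed default dataset name
#   dataset_dir: <cache_path>/.cache/data
#   output_dir: <cache_path>/.cache/saves/pdf2model_cache_20250101_123000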
def verify_environment():
    """Verify the runtime environment"""
    print("Checking environment...")
    missing_deps = []
    try:
        import llamafactory
        print("✅ LlamaFactory installed")
    except ImportError:
        missing_deps.append("llamafactory[torch,metrics]")
    try:
        import yaml
        print("✅ PyYAML installed")
    except ImportError:
        missing_deps.append("pyyaml")
    if missing_deps:
        print(f"❌ Missing dependencies: {', '.join(missing_deps)}")
        print(f"Install with: pip install {' '.join(missing_deps)}")
        return False
    return True


def check_required_files():
    """Check that the required built-in scripts exist"""
    # Check all required built-in scripts
    required_scripts = [
        "path_to_jsonl_script.py",
        "llama_factory_trainer.py"
    ]
    missing_scripts = []
    for script in required_scripts:
        script_path = get_dataflow_script_path(script)
        if script_path is None:
            missing_scripts.append(script)
        else:
            print(f"✅ Found built-in script: {script}")
    if missing_scripts:
        print(f"❌ Missing built-in scripts: {', '.join(missing_scripts)}")
        print("These should be part of the dataflow installation")
        return False
    # Check whether the customizable script exists in the user directory
    current_dir = Path(os.getcwd())
    customizable_script = current_dir / "pdf_to_qa_pipeline.py"
    if customizable_script.exists():
        print("✅ Found customizable script: pdf_to_qa_pipeline.py")
    else:
        print("❌ Missing customizable script: pdf_to_qa_pipeline.py")
        print("Run 'dataflow pdf2model init' first")
        return False
    return True
def cli_pdf2model_init(cache_path: str = "./", model_name: str = "Qwen/Qwen2.5-7B-Instruct") -> bool:
    """
    PDF2Model initialization:
    0. Copy only customizable scripts to the current directory
    1. Create train_config.yaml in the .cache directory
    """
    print("Starting PDF2Model initialization...")
    print(f"Cache directory: {cache_path}")
    print(f"Model: {model_name}")
    print("Output directory: pdf2model_cache_<timestamp>")
    print("-" * 60)
    if not verify_environment():
        return False
    try:
        # Step 0: Copy only customizable scripts
        if not copy_customizable_scripts():
            return False
        # Step 1: Create training configuration
        print("Step 1: Creating training configuration...")
        config_file = create_train_config_yaml(cache_path, model_name)
        if config_file:
            print("PDF2Model initialization completed!")
            return True
        else:
            print("Failed to create training configuration")
            return False
    except Exception as e:
        print(f"Initialization failed: {e}")
        return False
def get_latest_model_dir(cache_path_obj):
    """Return the newest model directory (by timestamp)."""
    saves_dir = cache_path_obj / ".cache" / "saves"
    if not saves_dir.exists():
        return None
    # Find all directories whose name starts with pdf2model_cache_
    model_dirs = []
    for dir_path in saves_dir.iterdir():
        if dir_path.is_dir() and dir_path.name.startswith('pdf2model_cache_'):
            # Validate the timestamp format (YYYYMMDD_HHMMSS)
            timestamp_part = dir_path.name.replace('pdf2model_cache_', '')
            if len(timestamp_part) == 15 and timestamp_part[8] == '_':
                date_part = timestamp_part[:8]
                time_part = timestamp_part[9:]
                if date_part.isdigit() and time_part.isdigit() and len(time_part) == 6:
                    model_dirs.append(dir_path)
    if not model_dirs:
        return None
    # Sort by name; the timestamps sort naturally
    model_dirs.sort(key=lambda x: x.name, reverse=True)
    return model_dirs[0]
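# Example of the directory layout this helper scans (hypothetical timestamps):
#   .cache/saves/pdf2model_cache_20250101_123000/
#   .cache/saves/pdf2model_cache_20250315_091500/   <- newest, returned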
def cli_pdf2model_train(lf_yaml: str = ".cache/train_config.yaml", cache_path: str = "./") -> bool:
    """
    Start PDF2Model training using a mix of built-in and user scripts
    """
    print("Starting PDF2Model training...")
    current_dir = Path(os.getcwd())
    cache_path_obj = Path(cache_path)
    if not cache_path_obj.is_absolute():
        cache_path_obj = current_dir / cache_path_obj
    config_path_obj = Path(lf_yaml)
    if not config_path_obj.is_absolute():
        config_path_obj = current_dir / config_path_obj
    if not verify_environment():
        return False
    if not check_required_files():
        return False
    if not config_path_obj.exists():
        print(f"Training config file not found: {config_path_obj}")
        print(f"{Style.BRIGHT}Run 'dataflow pdf2model init' first")
        return False
    print("-" * 60)
    try:
        # Step 1: PDF Detection
        script1_path = get_dataflow_script_path("path_to_jsonl_script.py")
        args1 = ["./", "--output", str(cache_path_obj / ".cache" / "gpu" / "pdf_list.jsonl")]
        if not run_script_with_args(script1_path, "Step 1: PDF Detection", args1, cwd=str(current_dir)):
            return False
        # Step 2: Data Processing
        script2 = current_dir / "pdf_to_qa_pipeline.py"
        args2 = ["--cache", cache_path]
        if not run_script_with_args(script2, "Step 2: Data Processing", args2, cwd=str(current_dir)):
            return False
        # Step 2.5: Create dataset_info.json (dynamically)
        print(f"\n{Fore.BLUE}Step 2.5: Creating dataset_info.json{Style.RESET_ALL}")
        # Read the training config to get the dataset name
        try:
            with open(config_path_obj, 'r', encoding='utf-8') as f:
                train_config = yaml.safe_load(f)
            dataset_name = train_config.get('dataset')
            if isinstance(dataset_name, list):
                dataset_name = dataset_name[0]  # take the first entry if it is a list
            if not dataset_name:
                print("Warning: No dataset name found in train_config.yaml, using default 'kb_qa'")
                dataset_name = 'kb_qa'
            print(f"Dataset name from config: {dataset_name}")
        except Exception as e:
            print(f"Warning: Could not read train_config.yaml: {e}")
            print("Using default dataset name: kb_qa")
            dataset_name = 'kb_qa'
        # Create dataset_info.json
        dataset_info_path = cache_path_obj / ".cache" / "data" / "dataset_info.json"
        dataset_info_path.parent.mkdir(parents=True, exist_ok=True)
        dataset_info = {
            dataset_name: {  # use the name read from the config
                "file_name": "qa.json",
                "formatting": "alpaca",
                "columns": {
                    "prompt": "instruction",
                    "query": "input",
                    "response": "output"
                }
            }
        }
        with open(dataset_info_path, 'w', encoding='utf-8') as f:
            json.dump(dataset_info, f, indent=2, ensure_ascii=False)
        print(f"Created: {dataset_info_path}")
        print(f"Dataset registered as: {dataset_name}")
        print(f"{Fore.GREEN}✅ Step 2.5: Creating dataset_info.json completed{Style.RESET_ALL}")
        # Step 3: Data Conversion - skip if qa.json is already in the right format
        print(f"\n{Fore.BLUE}Step 3: Data Conversion{Style.RESET_ALL}")
        qa_json_path = cache_path_obj / ".cache" / "data" / "qa.json"
        if qa_json_path.exists():
            print("✅ qa.json already in correct format, skipping conversion")
            print(f"{Fore.GREEN}✅ Step 3: Data Conversion completed{Style.RESET_ALL}")
        else:
            print(f"❌ qa.json not found at {qa_json_path}")
            return False
        # Step 4: Training
        script4_path = get_dataflow_script_path("llama_factory_trainer.py")
        args4 = ["--config", str(config_path_obj), "--cache", cache_path]
        if not run_script_with_args(script4_path, "Step 4: Training", args4, cwd=str(current_dir)):
            return False
        # Show completion info
        try:
            with open(config_path_obj, 'r', encoding='utf-8') as f:
                config = yaml.safe_load(f)
            actual_output_dir = config.get('output_dir', 'unknown')
        except Exception:
            actual_output_dir = 'unknown'
        print("Training completed successfully!")
        print(f"Model saved to: {actual_output_dir}")
        print("Next steps:")
        print(f"{Style.BRIGHT}Test the trained model with 'dataflow chat'")
        return True
    except Exception as e:
        print(f"Training error: {e}")
        return False
def cli_pdf2model_chat(model_path=None, cache_path="./", base_model=None):
    """Start the LlamaFactory chat interface"""
    current_dir = Path(os.getcwd())
    # Resolve the cache path
    cache_path_obj = Path(cache_path)
    if not cache_path_obj.is_absolute():
        cache_path_obj = current_dir / cache_path_obj
    # Determine the model path
    if model_path is None:
        # Use the newest model directory
        latest_model_dir = get_latest_model_dir(cache_path_obj)
        if latest_model_dir:
            model_path = latest_model_dir
        else:
            print("No trained model found")
            print("Run 'dataflow pdf2model train' to train a model first")
            return False
    model_path = Path(model_path)
    if not model_path.exists():
        print(f"Model not found: {model_path}")
        print("Run 'dataflow pdf2model train' to train a model first")
        return False
    # Verify this is a valid adapter directory
    adapter_files = [
        "adapter_config.json",
        "adapter_model.bin",
        "adapter_model.safetensors"
    ]
    has_adapter = any((model_path / f).exists() for f in adapter_files)
    if not has_adapter:
        print(f"No adapter files found in {model_path}")
        print("This doesn't appear to be a trained adapter directory.")
        print("Expected files: adapter_config.json, adapter_model.bin/safetensors")
        return False
    # Determine the base model path, reading the configs defensively
    if base_model is None:
        # Try to read the base model from the training config
        config_file = cache_path_obj / ".cache" / "train_config.yaml"
        if config_file.exists():
            try:
                with open(config_file, 'r', encoding='utf-8') as f:
                    config = yaml.safe_load(f)
                base_model = config.get('model_name_or_path')
                if base_model:
                    print(f"Found base model in config: {base_model}")
            except Exception as e:
                print(f"Warning: Could not read config file: {e}")
        # Fall back to adapter_config.json
        if not base_model:
            adapter_config_path = model_path / "adapter_config.json"
            if adapter_config_path.exists():
                try:
                    with open(adapter_config_path, 'r', encoding='utf-8') as f:
                        adapter_config = json.load(f)
                    base_model = adapter_config.get('base_model_name_or_path')
                    if base_model:
                        print(f"Found base model in adapter config: {base_model}")
                except Exception as e:
                    print(f"Warning: Could not read adapter config: {e}")
    # If the base model still cannot be determined, fail instead of guessing a default
    if not base_model:
        print("Cannot determine base model path")
        print("Please ensure your training config contains 'model_name_or_path'")
        print("Or check that adapter_config.json exists and contains 'base_model_name_or_path'")
        return False
    # Check LlamaFactory
    try:
        import llamafactory
        print("LlamaFactory available")
    except ImportError:
        print("LlamaFactory not installed")
        print("Install with: pip install llamafactory[torch,metrics]")
        return False
    # Launch the chat directly with command-line arguments
    chat_cmd = [
        "llamafactory-cli", "chat",
        "--model_name_or_path", base_model,
        "--adapter_name_or_path", str(model_path.absolute())
    ]
    print(f"Base model: {base_model}")
    print(f"Adapter path: {model_path}")
    print(f"Command: {' '.join(chat_cmd)}")
    print("-" * 60)
    print("Starting chat session...")
    print("-" * 60)
    try:
        subprocess.run(chat_cmd, check=True)
        print("\nChat session completed")
        return True
    except subprocess.CalledProcessError as e:
        print(f"\nChat failed: {e}")
        return False
    except KeyboardInterrupt:
        print("\n\nChat session ended by user")
        return True