Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [OpenBMB]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
# MiniCPM4
Blazing fast, up to 220x faster! MiniCPM4.0-8B is the first natively sparse model: an extreme 5% sparsity combined with a burst of system-level innovations ushers in the era of long-context processing on end devices.
## Paper
`MiniCPM4: Ultra-Efficient LLMs on End Devices`
- https://arxiv.org/pdf/2506.07900
## Model Architecture
The core of MiniCPM4 is a Transformer decoder-only architecture with the InfLLM 2.0 hybrid sparse-attention structure. Its "dual-frequency gear-shifting" mechanism automatically switches attention modes based on task characteristics: sparse attention is enabled for demanding long-context and deep-reasoning tasks to reduce computational complexity, while dense attention is used for short-text scenarios to preserve accuracy, giving efficient handling of both long and short inputs.
<div align=center>
<img src="./doc/structure.png"/>
</div>
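As a rough illustration of this dual-mode idea (not the implementation shipped with the model, whose kernels, block sizes, and switching policy live inside the InfLLM 2.0 code), the dispatch can be pictured as follows; the threshold value is a made-up example:

```python
import torch

def dense_attention(q, k, v):
    # Standard scaled dot-product attention for a single query vector q of shape (dim,)
    # over keys/values k, v of shape (seq_len, dim).
    scores = (k @ q) / q.numel() ** 0.5
    return torch.softmax(scores, dim=-1) @ v

def hybrid_attention(q, k, v, sparse_fn, long_context_threshold=4096):
    # "Dual-frequency gear-shifting" in miniature: short contexts take the dense
    # path for accuracy, long contexts take the sparse path for speed.
    if k.shape[0] <= long_context_threshold:
        return dense_attention(q, k, v)
    return sparse_fn(q, k, v)
```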
## Algorithm
The InfLLMv2 sparse-attention architecture used by MiniCPM 4.0 changes how a standard Transformer computes relevance: the text is split into blocks and regions, and an intelligent selection mechanism "spot-checks" attention only over the most relevant regions instead of recomputing relevance token by token. As a result, the attention layers need only about 1/10 of the computation to process long contexts.
<div align=center>
<img src="./doc/Sparse_Attention.png"/>
</div>
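The block-wise "spot check" can be sketched as follows: pool each key block into one representative, score the blocks against the query, and run attention only inside the top-scoring blocks. The block size, number of selected blocks, and mean-pooling used here are illustrative assumptions, not the actual InfLLM v2 kernel; this function could be plugged in as the `sparse_fn` of the dispatch sketch above.

```python
import torch
import torch.nn.functional as F

def block_sparse_attention(q, k, v, block_size=64, top_k_blocks=8):
    # q: (dim,); k, v: (seq_len, dim). Illustrative block selection only.
    seq_len, dim = k.shape
    n_blocks = (seq_len + block_size - 1) // block_size
    k_pad = F.pad(k, (0, 0, 0, n_blocks * block_size - seq_len))
    # One pooled representative per block; score every block against the query.
    block_repr = k_pad.view(n_blocks, block_size, dim).mean(dim=1)
    top_blocks = (block_repr @ q).topk(min(top_k_blocks, n_blocks)).indices
    # Gather the token positions of the selected blocks (clipped to seq_len).
    idx = torch.cat([
        torch.arange(b * block_size, min((b + 1) * block_size, seq_len))
        for b in top_blocks.tolist()
    ])
    # Dense attention restricted to the selected positions.
    scores = (k[idx] @ q) / dim ** 0.5
    return torch.softmax(scores, dim=-1) @ v[idx]
```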
## Environment Setup
```
mv MiniCPM4_pytorch MiniCPM4
```
### Hardware Requirements
DCU model: K100AI; nodes: 1; cards: 4.
### Docker (Option 1)
```
docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.4.1-ubuntu22.04-dtk25.04.1-py3.10
# Replace <your IMAGE ID> with the ID of the image pulled above; for this image it is e50d644287fd
docker run -it --shm-size=64G -v $PWD/MiniCPM4:/home/MiniCPM4 -v /opt/hyhal:/opt/hyhal:ro --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --name minicpm4 <your IMAGE ID> bash
cd /home/MiniCPM4
pip install -r requirements.txt # requirements.txt
```
### Dockerfile (Option 2)
```
cd /home/MiniCPM4/docker
docker build --no-cache -t minicpm4:latest .
docker run --shm-size=64G --name minicpm4 -v /opt/hyhal:/opt/hyhal:ro --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video -v $PWD/../../MiniCPM4:/home/MiniCPM4 -it minicpm4 bash
# If installing the environment during the Dockerfile build takes too long, comment out the pip install inside it and install the Python packages after starting the container: pip install -r requirements.txt
```
### Anaconda (Option 3)
1. The DCU-specific deep-learning libraries required by this project can be downloaded from the Guanghe (SourceFind) developer community:
- https://developer.sourcefind.cn/tool/
```
DTK driver: 25.04.1
python: 3.10
torch: 2.4.1
torchvision: 0.19.1
triton: 3.0.0
flash-attn: 2.6.1
deepspeed: 0.14.2
apex: 1.4.0
transformers: 4.53.2
```
The DCU models supported by each deep-learning library can be checked here: [DAS downloads](https://das.sourcefind.cn:55011/portal/#/home)
`Tip: the DTK driver, Python, torch, and other DCU-related tool versions listed above must match each other exactly.`
2. Install the remaining, non-DCU-specific libraries from requirements.txt:
```
cd /home/MiniCPM4
pip install -r requirements.txt # requirements.txt
```
Troubleshooting:
```
1. TypeError: Phi3LongRoPEScaledRotaryEmbedding._compute_cos_sin_cache() missing 3 required positional arguments: 'max_position_embeddings', 'rescale_factors', and 'mscale'
File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/minicpm.py", line 245
The following line is unnecessary because the RotaryEmbedding constructor already handles the cache initialization, so comment it out:
# self.rotary_emb.cos_sin_cache = self.rotary_emb._compute_cos_sin_cache(
# )
2. ValueError: You must use the new past_key_values format, such as the Cache class, instead of the old tuple format.
openbmb/MiniCPM4-8B/modeling_minicpm.py", line 2052
Replace the following code:
if use_legacy_cache:
raise ValueError(
'You must use the new past_key_values format, such as the Cache class, instead of the old tuple format.'
)
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
with:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
```
## Dataset
`The project ships with example datasets for experimentation`
```
/home/MiniCPM4/finetune/data/
├── AdvertiseGenChatML
| ├── train.json
| └── dev.json
└── ocnli_public_chatml
├── train.json
└── dev.json
```
For more details, see the upstream project's [`README_origin`](./README_origin.md).
## Training
### Single node, multiple cards
Pretrained-weight directory layout:
```
/home/MiniCPM4/
└── openbmb/MiniCPM4-8B
```
```
cd /home/MiniCPM4/finetune
bash lora_finetune_minicpm4.sh # MiniCPM4-8B is used as the example here; models of other sizes work the same way.
```
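After fine-tuning, the LoRA adapter produced by the script can typically be attached to the base checkpoint with `peft` for inference. This is only a sketch under the assumption that the script writes a LoRA adapter directory; the adapter path below is a placeholder, so use whatever `lora_finetune_minicpm4.sh` actually outputs:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base_path = "openbmb/MiniCPM4-8B"      # base checkpoint used for fine-tuning
adapter_path = "output/lora_adapter"   # placeholder: directory written by the finetune script

tokenizer = AutoTokenizer.from_pretrained(base_path, trust_remote_code=True)
base = AutoModelForCausalLM.from_pretrained(
    base_path, torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="auto"
)
# Attach the LoRA adapter and optionally merge it into the base weights.
model = PeftModel.from_pretrained(base, adapter_path)
model = model.merge_and_unload()
```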
## Inference
### Single node, single card
```
cd /home/MiniCPM4
# Option 1: transformers inference
python infer_transformers.py
# Option 2: vLLM inference
python infer_vllm.py # Upstream vLLM does not yet support InfLLM-v2.
# Dense inference is available now; speculative sampling, quantization, and quantization plus speculative decoding will follow in later vLLM adaptations.
```
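For reference, a minimal transformers-based call looks roughly like the following. This is a sketch of the kind of code `infer_transformers.py` runs, not its exact contents; the prompt and generation parameters are illustrative:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

path = "openbmb/MiniCPM4-8B"
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    path, torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="cuda:0"
)

messages = [{"role": "user", "content": "Recommend five tourist attractions in Beijing."}]
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)
outputs = model.generate(input_ids, max_new_tokens=512, do_sample=True, top_p=0.7, temperature=0.7)
# Decode only the newly generated tokens, not the prompt.
print(tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True))
```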
For more details, see the upstream project's [`README_origin`](./README_origin.md).
## Results
An example inference result from the vLLM version:
`Input:`
```
Recommend five tourist attractions in Beijing.
```
`Output:`
```
Beijing, a city with a long history and deep cultural heritage, is home to many sought-after attractions. Here are five you should not miss:
1. **The Palace Museum (Forbidden City)**: The imperial palace of the Ming and Qing dynasties, the Forbidden City is the world's largest wooden palace complex and the largest museum of ancient court culture in China and the world. It holds a vast collection of precious artifacts such as paintings, calligraphy, porcelain, and jade, bringing visitors up close to traditional Chinese culture.
2. **The Great Wall**: A symbol of the Chinese nation and an outstanding example of ancient Chinese military defense engineering. The Badaling section is the most famous: set on steep terrain with winding ramparts, it is the best place to experience the Wall's grandeur.
3. **Tiananmen Square**: The largest city-center square in the world, it serves both as a venue for major state events and as a place for visitors to admire monumental architecture. The Tiananmen Gate Tower, the Monument to the People's Heroes, and the Chairman Mao Memorial Hall on the square are all witnesses to history.
4. **The Summer Palace**: China's best-preserved imperial garden, famed for its exquisite landscape art and rich cultural heritage. Kunming Lake, Longevity Hill, and the Long Corridor make visitors feel as though they have stepped into a living Chinese landscape painting.
5. **Yuanmingyuan (Old Summer Palace)**: Though ravaged by history, its ruins still reveal the former splendor of the Qing imperial gardens. The lotus ponds and the Western Mansions ruins evoke both the weight of history and the refinement of Chinese garden art.
These five sites showcase Beijing's deep historical and cultural heritage and are must-visit destinations for travelers to the city.
```
### Accuracy
Accuracy on DCU matches that on GPU (inference framework: vLLM). The training runs used only a small amount of demo data, intended solely to verify the training procedure, so they cannot serve as a reference for training accuracy.
## Application Scenarios
### Algorithm Category
`Conversational QA`
### Key Application Industries
`Manufacturing, media, finance, energy, healthcare, smart home, education`
## Pretrained Weights
ModelScope download: [OpenBMB/MiniCPM4-8B](https://www.modelscope.cn/models/OpenBMB/MiniCPM4-8B)
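The checkpoint can also be fetched with the ModelScope Python API; a small sketch (the cache directory is just an example):

```python
from modelscope import snapshot_download

# Download OpenBMB/MiniCPM4-8B into a local cache and return the local path.
local_dir = snapshot_download("OpenBMB/MiniCPM4-8B", cache_dir="./openbmb")
print(local_dir)
```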
## Source Repository & Issue Reporting
- http://developer.sourcefind.cn/codes/modelzoo/MiniCPM4_pytorch.git
## References
- https://github.com/OpenBMB/MiniCPM.git
from typing import List
import argparse
import gradio as gr
import torch
from threading import Thread
from PIL import Image
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
TextIteratorStreamer
)
import warnings
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
parser = argparse.ArgumentParser()
parser.add_argument("--model_path", type=str, default="openbmb/MiniCPM-2B-dpo-fp16")
parser.add_argument("--torch_dtype", type=str, default="bfloat16", choices=["float32", "bfloat16", "float16"])
parser.add_argument("--server_name", type=str, default="127.0.0.1")
parser.add_argument("--server_port", type=int, default=7860)
args = parser.parse_args()
# init model torch dtype
torch_dtype = args.torch_dtype
if torch_dtype == "" or torch_dtype == "bfloat16":
torch_dtype = torch.bfloat16
elif torch_dtype == "float32":
torch_dtype = torch.float32
elif torch_dtype == "float16":
torch_dtype = torch.float16
else:
raise ValueError(f"Invalid torch dtype: {torch_dtype}")
# init model and tokenizer
path = args.model_path
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch_dtype, device_map="cuda:0", trust_remote_code=True)
model_architectures = model.config.architectures[0]
def check_model_v(img_file_path: str = None):
'''
Check whether the loaded model is a MiniCPMV (vision) model.
Args:
img_file_path (str): Image filepath
Returns:
True if the model is MiniCPMV, else False
'''
if "MiniCPMV" in model_architectures:
return True
if isinstance(img_file_path, str):
gr.Warning('Only MiniCPMV model can support Image')
return False
if check_model_v():
model = model.to(dtype=torch.bfloat16)
# init gradio demo host and port
server_name = args.server_name
server_port = args.server_port
def hf_gen(dialog: List, top_p: float, temperature: float, repetition_penalty: float, max_dec_len: int):
"""generate model output with huggingface api
Args:
query (str): actual model input.
top_p (float): only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
temperature (float): Strictly positive float value used to modulate the logits distribution.
max_dec_len (int): The maximum numbers of tokens to generate.
Yields:
str: real-time generation results of hf model
"""
inputs = tokenizer.apply_chat_template(dialog, tokenize=False, add_generation_prompt=False)
enc = tokenizer(inputs, return_tensors="pt").to(next(model.parameters()).device)
streamer = TextIteratorStreamer(tokenizer)
generation_kwargs = dict(
enc,
do_sample=True,
top_k=0,
top_p=top_p,
temperature=temperature,
repetition_penalty=repetition_penalty,
max_new_tokens=max_dec_len,
pad_token_id=tokenizer.eos_token_id,
streamer=streamer,
)
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
answer = ""
for new_text in streamer:
answer += new_text
yield answer[4 + len(inputs):]
def hf_v_gen(dialog: List, top_p: float, temperature: float, repetition_penalty: float, max_dec_len: int,
img_file_path: str):
"""generate model output with huggingface api
Args:
query (str): actual model input.
top_p (float): only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
temperature (float): Strictly positive float value used to modulate the logits distribution.
max_dec_len (int): The maximum numbers of tokens to generate.
img_file_path (str): Image filepath.
Yields:
str: real-time generation results of hf model
"""
assert isinstance(img_file_path, str), 'Image must not be empty'
img = Image.open(img_file_path).convert('RGB')
generation_kwargs = dict(
image=img,
msgs=dialog,
context=None,
tokenizer=tokenizer,
sampling=True,
temperature=temperature,
top_p=top_p,
repetition_penalty=repetition_penalty,
max_new_tokens=max_dec_len
)
res, context, _ = model.chat(**generation_kwargs)
return res
def generate(chat_history: List, query: str, top_p: float, temperature: float, repetition_penalty: float, max_dec_len: int,
img_file_path: str = None):
"""generate after hitting "submit" button
Args:
chat_history (List): [[q_1, a_1], [q_2, a_2], ..., [q_n, a_n]]. list that stores all QA records
query (str): query of current round
top_p (float): only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
temperature (float): strictly positive float value used to modulate the logits distribution.
max_dec_len (int): the maximum number of tokens to generate.
img_file_path (str): Image filepath.
Yields:
List: [[q_1, a_1], [q_2, a_2], ..., [q_n, a_n], [q_n+1, a_n+1]]. chat_history + QA of current round.
"""
assert query != "", "Input must not be empty!!!"
# apply chat template
model_input = []
for q, a in chat_history:
model_input.append({"role": "user", "content": q})
model_input.append({"role": "assistant", "content": a})
model_input.append({"role": "user", "content": query})
# yield model generation
chat_history.append([query, ""])
if check_model_v():
chat_history[-1][1] = hf_v_gen(model_input, top_p, temperature, repetition_penalty, max_dec_len, img_file_path)
yield gr.update(value=""), chat_history
return
for answer in hf_gen(model_input, top_p, temperature, repetition_penalty, max_dec_len):
chat_history[-1][1] = answer.strip("</s>")
yield gr.update(value=""), chat_history
def regenerate(chat_history: List, top_p: float, temperature: float, repetition_penalty: float, max_dec_len: int,
img_file_path: str = None):
"""re-generate the answer of last round's query
Args:
chat_history (List): [[q_1, a_1], [q_2, a_2], ..., [q_n, a_n]]. list that stores all QA records
top_p (float): only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
temperature (float): strictly positive float value used to modulate the logits distribution.
max_dec_len (int): the maximum number of tokens to generate.
img_file_path (str): Image filepath.
Yields:
List: [[q_1, a_1], [q_2, a_2], ..., [q_n, a_n]]. chat_history
"""
assert len(chat_history) >= 1, "History is empty. Nothing to regenerate!!"
# apply chat template
model_input = []
for q, a in chat_history[:-1]:
model_input.append({"role": "user", "content": q})
model_input.append({"role": "assistant", "content": a})
model_input.append({"role": "user", "content": chat_history[-1][0]})
# yield model generation
if check_model_v():
chat_history[-1][1] = hf_v_gen(model_input, top_p, temperature, repetition_penalty, max_dec_len, img_file_path)
yield gr.update(value=""), chat_history
return
for answer in hf_gen(model_input, top_p, temperature, repetition_penalty, max_dec_len):
chat_history[-1][1] = answer.strip("</s>")
yield gr.update(value=""), chat_history
def clear_history():
"""clear all chat history
Returns:
List: empty chat history
"""
return []
def reverse_last_round(chat_history):
"""reverse last round QA and keep the chat history before
Args:
chat_history (List): [[q_1, a_1], [q_2, a_2], ..., [q_n, a_n]]. list that stores all QA records
Returns:
List: [[q_1, a_1], [q_2, a_2], ..., [q_n-1, a_n-1]]. chat_history without last round.
"""
assert len(chat_history) >= 1, "History is empty. Nothing to reverse!!"
return chat_history[:-1]
# launch gradio demo
with gr.Blocks(theme="soft") as demo:
gr.Markdown("""# MiniCPM Gradio Demo""")
with gr.Row():
with gr.Column(scale=1):
top_p = gr.Slider(0, 1, value=0.8, step=0.1, label="top_p")
temperature = gr.Slider(0.1, 2.0, value=0.5, step=0.1, label="temperature")
repetition_penalty = gr.Slider(0.1, 2.0, value=1.1, step=0.1, label="repetition_penalty")
max_dec_len = gr.Slider(1, 1024, value=1024, step=1, label="max_dec_len")
img_file_path = gr.Image(label="upload image", type='filepath', show_label=False)
with gr.Column(scale=5):
chatbot = gr.Chatbot(bubble_full_width=False, height=400)
user_input = gr.Textbox(label="User", placeholder="Input your query here!", lines=8)
with gr.Row():
submit = gr.Button("Submit")
clear = gr.Button("Clear")
regen = gr.Button("Regenerate")
reverse = gr.Button("Reverse")
img_file_path.change(check_model_v, inputs=[img_file_path], outputs=[])
submit.click(generate, inputs=[chatbot, user_input, top_p, temperature, repetition_penalty,
max_dec_len, img_file_path], outputs=[user_input, chatbot])
regen.click(regenerate, inputs=[chatbot, top_p, temperature, repetition_penalty,
max_dec_len, img_file_path], outputs=[user_input, chatbot])
clear.click(clear_history, inputs=[], outputs=[chatbot])
reverse.click(reverse_last_round, inputs=[chatbot], outputs=[chatbot])
demo.queue()
demo.launch(server_name=server_name, server_port=server_port, show_error=True)
"""
Package versions used by langchain_demo:
langchain 0.2.6
langchain-community 0.2.1
langchain-core 0.2.19
langchain-text-splitters 0.2.0
langchainplus-sdk 0.0.20
pypdf 4.3.0
pydantic 2.8.2
pydantic_core 2.20.1
transformers 4.41.1
triton 2.3.0
trl 0.8.6
vllm 0.5.0.post1+cu122
vllm-flash-attn 2.5.9
vllm_nccl_cu12 2.18.1.0.4.0
You only need a GPU with at least 6 GB of VRAM (which is plenty) to run smooth RAG on a consumer-grade card.
Usage:
1. Run pull_request/rag/langchain_demo.py
2. Upload pdf/txt files (multiple files in the same directory are allowed)
3. Type your question.
Very low VRAM (4 GB) usage:
1. Quantize the model following MiniCPM/quantize/readme.md; quantizing MiniCPM-1B-sft-bf16 is recommended
2. Point cpm_model_path to the quantized model directory
3. Make sure encode_model_device is set to cpu
"""
from langchain.document_loaders import PyPDFLoader, TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
from langchain.embeddings.huggingface import HuggingFaceBgeEmbeddings
from argparse import ArgumentParser
from langchain.llms.base import LLM
from typing import Any, List, Optional
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
import torch
from langchain.prompts import PromptTemplate
from pydantic.v1 import Field
import re
import gradio as gr
parser = ArgumentParser()
# LLM settings
parser.add_argument(
"--cpm_model_path",
type=str,
default="openbmb/MiniCPM-1B-sft-bf16",
help="MiniCPM模型路径或者huggingface id"
)
parser.add_argument(
"--cpm_device", type=str, default="cuda:0", choices=["auto", "cuda:0"],
help="MiniCPM模型所在设备,默认为cuda:0"
)
parser.add_argument("--backend", type=str, default="torch", choices=["torch", "vllm"],
help="使用torch还是vllm后端,默认为torch"
)
# Embedding model settings
parser.add_argument(
"--encode_model", type=str, default="BAAI/bge-base-zh",
help="用于召回编码的embedding模型,默认为BAAI/bge-base-zh,可输入本地地址"
)
parser.add_argument(
"--encode_model_device", type=str, default="cpu", choices=["cpu", "cuda:0"],
help="用于召回编码的embedding模型所在设备,默认为cpu"
)
parser.add_argument("--query_instruction", type=str, default="",help="召回时增加的前缀")
parser.add_argument(
"--file_path", type=str, default="/root/ld/pull_request/rag/红楼梦.pdf",
help="需要检索的文本文件路径,gradio运行时无效"
)
# Generation parameters
parser.add_argument("--top_k", type=int, default=3)
parser.add_argument("--top_p", type=float, default=0.7)
parser.add_argument("--temperature", type=float, default=0.7)
parser.add_argument("--max_new_tokens", type=int, default=4096)
parser.add_argument("--repetition_penalty", type=float, default=1.02)
# Retriever settings
parser.add_argument("--embed_top_k", type=int, default=5, help="Number of most similar chunks to retrieve")
parser.add_argument("--chunk_size", type=int, default=256, help="Chunk length used when splitting text")
parser.add_argument("--chunk_overlap", type=int, default=50, help="Overlap length used when splitting text")
args = parser.parse_args()
def clean_text(text):
"""
Clean text by keeping Chinese characters, English letters, digits, and common
punctuation, and stripping everything else.
Args:
text (str): raw text to clean.
Returns:
str: cleaned text.
"""
# Characters to keep: Chinese, English letters, digits, common punctuation, and
# whitespace; everything outside this set is removed.
pattern = r'[^\u4e00-\u9fa5A-Za-z0-9.,;!?()"\'\s]'
# Strip all characters outside the allowed set
cleaned_text = re.sub(pattern, "", text)
# Collapse redundant whitespace
cleaned_text = re.sub(r"\s+", " ", cleaned_text)
return cleaned_text
class MiniCPM_LLM(LLM):
tokenizer: Any = Field(default=None)
model: Any = Field(default=None)
def __init__(self, model_path: str):
"""
MiniCPM model wrapped as a langchain LLM.
Args:
model_path (str): path of the MiniCPM model to load.
Sets:
self.model: the loaded MiniCPM model.
self.tokenizer: the tokenizer of the loaded MiniCPM model (torch backend only).
"""
super().__init__()
if args.backend == "vllm":
from vllm import LLM
self.model = LLM(
model=model_path, trust_remote_code=True, enforce_eager=True
)
else:
self.tokenizer = AutoTokenizer.from_pretrained(
model_path, trust_remote_code=True
)
self.model = AutoModelForCausalLM.from_pretrained(
model_path, trust_remote_code=True, torch_dtype=torch.float16
).to(args.cpm_device)
self.model = self.model.eval()
def _call(self, prompt, stop: Optional[List[str]] = None):
"""
Invocation entry point required by langchain's LLM interface.
Args:
prompt (str): the prompt text passed in.
Returns:
responds (str): the text generated by the model for the prompt.
"""
if args.backend == "torch":
inputs = self.tokenizer("<用户>{}".format(prompt), return_tensors="pt")
inputs = inputs.to(args.cpm_device)
# Generate
generate_ids = self.model.generate(
inputs.input_ids,
max_length=args.max_new_tokens,
temperature=args.temperature,
top_p=args.top_p,
repetition_penalty=args.repetition_penalty,
)
responds = self.tokenizer.batch_decode(
generate_ids,
skip_special_tokens=True,
clean_up_tokenization_spaces=False,
)[0]
# responds, history = self.model.chat(self.tokenizer, prompt, temperature=args.temperature, top_p=args.top_p, repetition_penalty=1.02)
else:
from vllm import SamplingParams
params_dict = {
"n": 1,
"best_of": 1,
"presence_penalty": args.repetition_penalty,
"frequency_penalty": 0.0,
"temperature": args.temperature,
"top_p": args.top_p,
"top_k": args.top_k,
"use_beam_search": False,
"length_penalty": 1,
"early_stopping": False,
"stop": None,
"stop_token_ids": None,
"ignore_eos": False,
"max_tokens": args.max_new_tokens,
"logprobs": None,
"prompt_logprobs": None,
"skip_special_tokens": True,
}
sampling_params = SamplingParams(**params_dict)
prompt = "<用户>{}<AI>".format(prompt)
responds = self.model.generate(prompt, sampling_params)
responds = responds[0].outputs[0].text
return responds
@property
def _llm_type(self) -> str:
return "MiniCPM_LLM"
# Load PDF and TXT files
def load_documents(file_paths):
"""
Load text from pdf and txt files and apply simple cleaning.
Args:
file_paths (str or list): a file path or a list of file paths.
Returns:
documents (list): the loaded documents.
"""
# Normalize to a list of file paths
files_list = file_paths if isinstance(file_paths, list) else [file_paths]
documents = []
for file_path in files_list:
if file_path.endswith(".pdf"):
loader = PyPDFLoader(file_path)
elif file_path.endswith(".txt"):
loader = TextLoader(file_path)
else:
raise ValueError("Unsupported file type")
doc = loader.load()
doc[0].page_content = clean_text(doc[0].page_content)
documents.extend(doc)
return documents
def load_models():
"""
Load the LLM and the embedding model.
Returns:
llm: the MiniCPM model.
embedding_models: the embedding model.
"""
llm = MiniCPM_LLM(model_path=args.cpm_model_path)
embedding_models = HuggingFaceBgeEmbeddings(
model_name=args.encode_model,
model_kwargs={"device": args.encode_model_device}, # 或者 'cuda' 如果你有GPU
encode_kwargs={
"normalize_embeddings": True, # 是否归一化嵌入
"show_progress_bar": True, # 是否显示进度条
"convert_to_numpy": True, # 是否将输出转换为numpy数组
"batch_size": 8, # 批处理大小'
},
query_instruction=args.query_instruction,
)
return llm, embedding_models
# Split and embed documents
def embed_documents(documents, embedding_models):
"""
Split the documents into chunks and embed them.
Args:
documents (list): the loaded documents.
embedding_models: the embedding model.
Returns:
vectorstore: the vector store.
"""
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
)
texts = text_splitter.split_documents(documents)
vectorstore = Chroma.from_documents(texts, embedding_models)
return vectorstore
def create_prompt_template():
"""
Create the custom prompt template.
Returns:
PROMPT: the custom prompt template.
"""
custom_prompt_template = """请使用以下内容片段对问题进行最终回复,如果内容中没有提到的信息不要瞎猜,严格按照内容进行回答,不要编造答案,如果无法从内容中找到答案,请回答“片段中未提及,无法回答”,不要编造答案。
Context:
{context}
Question: {question}
FINAL ANSWER:"""
PROMPT = PromptTemplate(
template=custom_prompt_template, input_variables=["context", "question"]
)
return PROMPT
# Build the RAG chain
def create_rag_chain(llm, prompt):
# qa=load_qa_with_sources_chain(llm, chain_type="stuff")
qa = prompt | llm
return qa
def analysis_links(docs):
"""
Build a citation string from the retrieved documents.
Args:
docs (list): retrieved langchain Document objects (with .metadata and .page_content).
Returns:
links_string: citation string of the related documents, formatted as "docname page content".
Example:
>>> docs = [
...     Document(page_content='This is the first document.', metadata={'source': 'Document1', 'page': 1}),
...     Document(page_content='This is the second document.', metadata={'source': 'Document2', 'page': 2}),
... ]
>>> analysis_links(docs)
'Document1 page:1\n\nThis is the first document.\n\nDocument2 page:2\n\nThis is the second document.\n\n'
"""
links_string = ""
for i in docs:
i.metadata["source"] = i.metadata["source"].split("/")[-1]
i.metadata["content"] = i.page_content
links_string += f"{i.metadata['source']} page:{i.metadata['page']}\n\n{i.metadata['content']}\n\n"
return links_string
# Main entry point
def main():
# Load documents
documents = load_documents(args.file_path)
# Embed documents
vectorstore = embed_documents(documents, embedding_models)
# Build the custom prompt template
Prompt = create_prompt_template()
# Build the RAG chain
rag_chain = create_rag_chain(llm, Prompt)
# Interactive query loop
while True:
query = input("Enter your query (type 'exit' to quit): ")
if query == "exit":
break
docs = vectorstore.similarity_search(query, k=args.embed_top_k)
all_links = analysis_links(docs)
final_result = rag_chain.invoke({"context": all_links, "question": query})
# result = rag_chain({"input_documents": docs, "question": query}, return_only_outputs=True)
print(final_result)
exist_file = None
def process_query(file, query):
global exist_file, documents, vectorstore, rag_chain
if file != exist_file:
# Load documents
documents = load_documents(file if isinstance(file, list) else file.name)
# Embed documents
vectorstore = embed_documents(documents, embedding_models)
# Build the custom prompt template
Prompt = create_prompt_template()
# Build the RAG chain
rag_chain = create_rag_chain(llm, Prompt)
exist_file = file
# Retrieve relevant chunks and generate the answer
docs = vectorstore.similarity_search(query, k=args.embed_top_k)
all_links = analysis_links(docs)
final_result = rag_chain.invoke({"context": all_links, "question": query})
# result = rag_chain({"input_documents": docs, "question": query}, return_only_outputs=False)
print(final_result)
final_result = final_result.split("FINAL ANSWER:")[-1]
return final_result, all_links
if __name__ == "__main__":
llm, embedding_models = load_models()
# If you don't need the web UI, run main() directly instead
#main()
with gr.Blocks(css="#textbox { height: 380%; }") as demo:
with gr.Row():
with gr.Column():
link_content = gr.Textbox(label="link_content", lines=30, max_lines=40)
with gr.Column():
file_input = gr.File(label="upload_files", file_count="multiple")
final_answer = gr.Textbox(label="final_answer", lines=5, max_lines=10)
query_input = gr.Textbox(
label="User",
placeholder="Input your query here!",
lines=5,
max_lines=10,
)
submit_button = gr.Button("Submit")
submit_button.click(
fn=process_query,
inputs=[file_input, query_input],
outputs=[final_answer, link_content],
)
demo.launch(share=True, show_error=True)
"""
Fast MiniCPM inference with MLX
If you are running on a Mac, you can use MLX for inference directly.
MiniCPM does not yet support conversion to the mlx format, so download the model already converted by the MLX community: [MiniCPM-2B-sft-bf16-llama-format-mlx](https://huggingface.co/mlx-community/MiniCPM-2B-sft-bf16-llama-format-mlx).
Install the corresponding dependency:
```bash
pip install mlx-lm
```
A simple one-shot inference command for MiniCPM-2B on a Mac:
```bash
python -m mlx_lm.generate --model mlx-community/MiniCPM-2B-sft-bf16-llama-format-mlx --prompt "hello, tell me a joke." --trust-remote-code
```
"""
from mlx_lm import load, generate
from jinja2 import Template
def chat_with_model():
model, tokenizer = load("mlx-community/MiniCPM-2B-sft-bf16-llama-format-mlx")
print("Model loaded. Start chatting! (Type 'quit' to stop)")
messages = []
chat_template = Template(
"{% for message in messages %}{% if message['role'] == 'user' %}{{'<用户>' + message['content'].strip() + '<AI>'}}{% else %}{{message['content'].strip()}}{% endif %}{% endfor %}")
while True:
user_input = input("You: ")
if user_input.lower() == 'quit':
break
messages.append({"role": "user", "content": user_input})
response = generate(model, tokenizer, prompt=chat_template.render(messages=messages), verbose=True)
print("Model:", response)
messages.append({"role": "ai", "content": response})
chat_with_model()
from typing import List
import argparse
import gradio as gr
from vllm import LLM, SamplingParams
import torch
from transformers import AutoTokenizer
parser = argparse.ArgumentParser()
parser.add_argument("--model_path", type=str, default="openbmb/MiniCPM-1B-sft-bf16")
parser.add_argument("--torch_dtype", type=str, default="bfloat16", choices=["float32", "bfloat16"])
parser.add_argument("--server_name", type=str, default="127.0.0.1")
parser.add_argument("--server_port", type=int, default=7860)
parser.add_argument("--max_tokens", type=int, default=2048)
# for MiniCPM-1B and MiniCPM-2B models, max_tokens should be set to 2048
args = parser.parse_args()
# init model torch dtype
torch_dtype = args.torch_dtype
if torch_dtype == "" or torch_dtype == "bfloat16":
torch_dtype = torch.bfloat16
elif torch_dtype == "float32":
torch_dtype = torch.float32
elif torch_dtype == "float16":
torch_dtype = torch.float16
else:
raise ValueError(f"Invalid torch dtype: {torch_dtype}")
# init model and tokenizer
path = args.model_path
llm = LLM(
model=path,
tensor_parallel_size=1,
dtype=torch_dtype,
trust_remote_code=True,
gpu_memory_utilization=0.9,
max_model_len=args.max_tokens
)
tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
server_name = args.server_name
server_port = args.server_port
def vllm_gen(dialog: List, top_p: float, temperature: float, max_dec_len: int):
"""generate model output with huggingface api
Args:
query (str): actual model input.
top_p (float): only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
temperature (float): Strictly positive float value used to modulate the logits distribution.
max_dec_len (int): The maximum numbers of tokens to generate.
Yields:
str: real-time generation results of hf model
"""
assert len(dialog) % 2 == 1
prompt = tokenizer.apply_chat_template(dialog, tokenize=False, add_generation_prompt=False)
token_ids = tokenizer.convert_tokens_to_ids(["<|im_end|>"])
params_dict = {
"n": 1,
"best_of": 1,
"presence_penalty": 1.0,
"frequency_penalty": 0.0,
"temperature": temperature,
"top_p": top_p,
"top_k": -1,
"use_beam_search": False,
"length_penalty": 1,
"early_stopping": False,
"stop": "<|im_end|>",
"stop_token_ids": token_ids,
"ignore_eos": False,
"max_tokens": max_dec_len,
"logprobs": None,
"prompt_logprobs": None,
"skip_special_tokens": True,
}
sampling_params = SamplingParams(**params_dict)
outputs = llm.generate(prompts=prompt, sampling_params=sampling_params)[0]
generated_text = outputs.outputs[0].text
return generated_text
def generate(chat_history: List, query: str, top_p: float, temperature: float, max_dec_len: int):
"""generate after hitting "submit" button
Args:
chat_history (List): [[q_1, a_1], [q_2, a_2], ..., [q_n, a_n]]. list that stores all QA records
query (str): query of current round
top_p (float): only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
temperature (float): strictly positive float value used to modulate the logits distribution.
max_dec_len (int): the maximum number of tokens to generate.
Yields:
List: [[q_1, a_1], [q_2, a_2], ..., [q_n, a_n], [q_n+1, a_n+1]]. chat_history + QA of current round.
"""
assert query != "", "Input must not be empty!!!"
# apply chat template
model_input = []
for q, a in chat_history:
model_input.append({"role": "user", "content": q})
model_input.append({"role": "assistant", "content": a})
model_input.append({"role": "user", "content": query})
# yield model generation
model_output = vllm_gen(model_input, top_p, temperature, max_dec_len)
chat_history.append([query, model_output])
return gr.update(value=""), chat_history
def regenerate(chat_history: List, top_p: float, temperature: float, max_dec_len: int):
"""re-generate the answer of last round's query
Args:
chat_history (List): [[q_1, a_1], [q_2, a_2], ..., [q_n, a_n]]. list that stores all QA records
top_p (float): only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
temperature (float): strictly positive float value used to modulate the logits distribution.
max_dec_len (int): the maximum number of tokens to generate.
Yields:
List: [[q_1, a_1], [q_2, a_2], ..., [q_n, a_n]]. chat_history
"""
assert len(chat_history) >= 1, "History is empty. Nothing to regenerate!!"
# apply chat template
model_input = []
for q, a in chat_history[:-1]:
model_input.append({"role": "user", "content": q})
model_input.append({"role": "assistant", "content": a})
model_input.append({"role": "user", "content": chat_history[-1][0]})
# yield model generation
model_output = vllm_gen(model_input, top_p, temperature, max_dec_len)
chat_history[-1][1] = model_output
return gr.update(value=""), chat_history
def clear_history():
"""clear all chat history
Returns:
List: empty chat history
"""
return []
def reverse_last_round(chat_history):
"""reverse last round QA and keep the chat history before
Args:
chat_history (List): [[q_1, a_1], [q_2, a_2], ..., [q_n, a_n]]. list that stores all QA records
Returns:
List: [[q_1, a_1], [q_2, a_2], ..., [q_n-1, a_n-1]]. chat_history without last round.
"""
assert len(chat_history) >= 1, "History is empty. Nothing to reverse!!"
return chat_history[:-1]
# launch gradio demo
with gr.Blocks(theme="soft") as demo:
gr.Markdown("""# MiniCPM Gradio Demo""")
with gr.Row():
with gr.Column(scale=1):
top_p = gr.Slider(0, 1, value=0.8, step=0.1, label="top_p")
temperature = gr.Slider(0.1, 2.0, value=0.5, step=0.1, label="temperature")
max_dec_len = gr.Slider(1, args.max_tokens, value=args.max_tokens, step=1, label="max_tokens")
with gr.Column(scale=5):
chatbot = gr.Chatbot(bubble_full_width=False, height=400)
user_input = gr.Textbox(label="User", placeholder="Input your query here!", lines=8)
with gr.Row():
submit = gr.Button("Submit")
clear = gr.Button("Clear")
regen = gr.Button("Regenerate")
reverse = gr.Button("Reverse")
submit.click(generate, inputs=[chatbot, user_input, top_p, temperature, max_dec_len], outputs=[user_input, chatbot])
regen.click(regenerate, inputs=[chatbot, top_p, temperature, max_dec_len], outputs=[user_input, chatbot])
clear.click(clear_history, inputs=[], outputs=[chatbot])
reverse.click(reverse_last_round, inputs=[chatbot], outputs=[chatbot])
demo.queue()
demo.launch(server_name=server_name, server_port=server_port, show_error=True)