Commit 6088e14e authored by chenych's avatar chenych

jifu v1.0

parent 2397728d
# chat_demo
## 树博士: A Large-Scale Intelligent Customer-Service Robot System Based on Retrieval-Augmented Generation
Technical-support intelligent Q&A service
## Environment Setup
### Docker (Option 1)
Adjust the -v path, docker_name, and imageID below to match your environment.
```bash
docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.1.0-centos7.6-dtk24.04-py310
docker run -it -v /path/your_code_data/:/path/your_code_data/ --shm-size=80G --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --name docker_name imageID bash
# Load the runtime environment variables
unzip dtk-cuda.zip -d /opt/dtk/
source /opt/dtk/cuda/env.sh
# Download the fastllm library
git clone http://developer.hpccube.com/codes/OpenDAS/fastllm.git
# Build fastllm
cd fastllm
mkdir build
cd build
cmake ..
make -j
# After the build finishes, install the lightweight Python tool package
cd tools # now in fastllm/build/tools
python setup.py install
cd /path/of/chat_demo
pip install faiss-1.7.2_dtk24.04_gitb7348e7df780-py3-none-any.whl
pip install -r requirements.txt
```
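Before continuing, a quick sanity check of the container environment can save time later. The sketch below only assumes the torch and faiss packages installed above; on this DTK build the DCU devices are expected to be visible through the torch.cuda API (the run scripts later in this README select cards via CUDA_VISIBLE_DEVICES).
```python
# Minimal sanity check for the container environment (sketch; assumes the wheels installed above).
import faiss
import torch

print("torch version:", torch.__version__)
print("DCU/accelerator visible:", torch.cuda.is_available())
print("faiss works:", faiss.IndexFlatL2(8).ntotal == 0)
```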
### Dockerfile (Option 2)
```
docker build -t chat_demo:latest .
docker run -dit --network=host --name=chat_demo --privileged --device=/dev/kfd --device=/dev/dri --ipc=host --shm-size=16G --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -u root --ulimit stack=-1:-1 --ulimit memlock=-1:-1 chat_demo:latest
docker exec -it chat_demo /bin/bash
```

树博士 is a domain-knowledge assistant built on RAG combined with an LLM. Key features:

1. Handles complex vertical-domain scenarios and answers user questions without producing "hallucinations"
2. Provides an algorithm pipeline for answering technical questions
3. Modular composition keeps deployment cost low while staying secure, reliable, and robust
## Step 1. Environment Setup
- Pull the image and create a container: [SourceFind image download](https://sourcefind.cn/#/image/dcu/pytorch?activeName=overview)
docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.1.0-ubuntu20.04-dtk24.04.1-py3.10
- Run the container:
docker run -dit --name assitant --privileged --device=/dev/kfd --device=/dev/dri/ --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --ulimit memlock=-1:-1 --ipc=host --network host --group-add video -v /opt/hyhal:/opt/hyhal -v /path/to/model:/opt/model:ro image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.1.0-centos7.6-dtk24.04-py310
Note: replace the -v argument with the host path where the models are stored.
- Install dependencies
```shell
git clone http://10.6.10.68/aps/ai.git
cd ai
# on CentOS use yum instead
apt update
apt install python-dev libxml2-dev libxslt1-dev antiword unrtf poppler-utils pstotext tesseract-ocr flac ffmpeg lame libmad0 libsox-fmt-mp3 sox libjpeg-dev swig libpulse-dev
# Download and install the other packages dtk requires: [faiss](http://10.6.10.68:8000/release/faiss/dtk24.04/faiss-1.7.2_dtk24.04_gitb7348e7df780-py3-none-any.whl)
wget http://10.6.10.68:8000/release/faiss/dtk24.04/faiss-1.7.2_dtk24.04_gitb7348e7df780-py3-none-any.whl
pip install faiss-1.7.2_dtk24.04_gitb7348e7df780-py3-none-any.whl
# Install the Firefox browser plus geckodriver, and configure the environment variables
sudo apt-get update
sudo apt-get install firefox # install the Firefox browser
wget https://github.com/mozilla/geckodriver/releases/latest/download/geckodriver-v0.34.0-linux64.tar.gz # download geckodriver
tar -xvzf geckodriver-v0.34.0-linux64.tar.gz # extract geckodriver
sudo mv geckodriver /usr/local/bin/ # move geckodriver to /usr/local/bin
nano ~/.bashrc # edit ~/.bashrc to configure the environment variables
export PATH=$PATH:/usr/local/bin
export GECKODRIVER=/usr/local/bin/geckodriver # add the two lines above to ~/.bashrc
source ~/.bashrc # save and exit the editor, then apply the changes immediately
echo $GECKODRIVER # verify the environment variable
chmod +x /usr/local/bin/geckodriver # troubleshooting: make sure geckodriver is executable; if not, grant it with chmod +x
# Install the Python dependencies
pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
```
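To verify that Firefox and geckodriver are wired up correctly for the web-search module, a short headless smoke test can be run. This is a sketch only; it assumes selenium was installed via requirements.txt.
```python
# Headless Firefox smoke test for the geckodriver setup above (sketch; assumes selenium from requirements.txt).
from selenium import webdriver
from selenium.webdriver.firefox.options import Options

options = Options()
options.add_argument('--headless')
driver = webdriver.Firefox(options=options)  # geckodriver is found via PATH (/usr/local/bin)
try:
    driver.get('https://www.bing.com')
    print('page title:', driver.title)
finally:
    driver.quit()
```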
The remaining steps are the same as for Docker (Option 1) above.
## Step 2. Prepare the Models
When running 树博士 for the first time, the required models must be downloaded locally. Download links:
[BCERerank download](https://modelscope.cn/models/maidalun/bce-reranker-base_v1/files)
[Text2vec-large-chinese download](https://modelscope.cn/models/Jerry0/text2vec-large-chinese/files)
[Llama3-8B-Chinese-Chat download](https://hf-mirror.com/shenzhi-wang/Llama3-8B-Chinese-Chat)
The bert-finetune-dcu model is a fine-tuned model; contact the developers to obtain it.
Make sure the download location matches the -v argument from Step 1.
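The two ModelScope-hosted models can also be fetched programmatically. The sketch below assumes the modelscope package and uses a placeholder cache_dir, which should match the -v mount from Step 1; Llama3-8B-Chinese-Chat can be cloned from the hf-mirror link above.
```python
# Optional: download the ModelScope-hosted models programmatically (sketch; assumes `pip install modelscope`).
from modelscope import snapshot_download

# cache_dir is a placeholder; point it at the host path mounted into the container in Step 1
snapshot_download('maidalun/bce-reranker-base_v1', cache_dir='/path/to/model')
snapshot_download('Jerry0/text2vec-large-chinese', cache_dir='/path/to/model')
```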
## Step 3. Edit the Configuration Files
ai/config.ini
```shell
[default]
work_dir = /path/to/your/ai/work_dir    # absolute path of ai/work_dir
bind_port = 8000    # port the service exposes
summarize_query = False    # whether to summarize the question before querying
use_rag = True    # whether to use RAG retrieval
use_template = False    # whether to use a template for RAG output
output_format = True    # whether to format output as Markdown
filter_illegal = True    # whether to screen for sensitive questions
[feature_database]
file_processor = 10    # number of file-parsing threads
[llm]
llm_service_address = http://127.0.0.1:8001    # address of the general LLM service
llm_model = /path/to/your/Llama3-8B-Chinese-Chat/    # general model name
cls_service_address = http://127.0.0.1:8002    # address of the classification model service
rag_service_address = http://127.0.0.1:8003    # address of the RAG service
max_input_length = 1400    # input length limit
```
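Both configuration files are plain INI files. A minimal sketch of reading them with Python's configparser (the inline_comment_prefixes option strips the `#` annotations shown above; paths are placeholders):
```python
# Sketch: read ai/config.ini with configparser (inline '#' comments are stripped).
import configparser

config = configparser.ConfigParser(inline_comment_prefixes=('#',))
config.read('/path/to/your/ai/config.ini')

print(config['default']['work_dir'])
print(config['default'].getboolean('use_rag'))
print(config['llm']['llm_service_address'])
```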
ai/rag/config.ini
```shell
[default]
work_dir = /path/to/your/ai/work_dir    # absolute path of ai/work_dir
bind_port = 8003    # port the service exposes
[rag]
embedding_model_path = /path/to/your/text2vec-large-chinese    # absolute path of the text2vec-large-chinese model directory
reranker_model_path = /path/to/your/bce-reranker-base_v1    # absolute path of the bce-reranker-base_v1 model directory
vector_top_k = 5    # number of results from the vector store
es_top_k = 5    # number of results from Elasticsearch
es_url = http://10.2.106.50:31920    # Elasticsearch address
index_name = dcu_knowledge_base    # Elasticsearch index name
```
The vector store must be obtained from the developers and placed under work_dir:
```shell
ll ai/work_dir/db_response/
total 173696
-rw-r--r-- 1 root root 154079277 Aug 22 10:04 index.faiss
-rw-r--r-- 1 root root  23767932 Aug 22 10:04 index.pkl
```
### Conda (Option 3)
The toolkits and deep-learning libraries this project needs for DCU GPUs can all be downloaded from the [HPC developer community](https://developer.hpccube.com/tool/).
```bash
DTK driver: dtk24.04
python: python3.10
torch: 2.1.0
```
`Tip: the versions of the DTK driver, Python, DeepSpeed, and the other tools above must match exactly.`
Install the remaining dependencies according to requirements.txt:
```
pip install faiss-1.7.2_dtk24.04_gitb7348e7df780-py3-none-any.whl
pip install -r requirements.txt
```
## Step 4. Run
- Start the LLM service
Specify the DCU card the service runs on; prefer an idle card.
export CUDA_VISIBLE_DEVICES='1'
Use vLLM to serve the standard OpenAI-compatible API:
python -m vllm.entrypoints.openai.api_server --model /opt/model/Llama3-8B-Chinese-Chat/ --enforce-eager --dtype float16 --trust-remote-code --port 8001 > llm.log 2>&1 &
- Start the classification model
python ai/classify/classify.py --model_path /opt/model/bert-finetune-dcu/ --port 8002 --dcu_id 2 > classify.log 2>&1 &
Arguments:
--model_path  path to the classification model
--port        classification service port
--dcu_id      DCU card to run on
- Start the RAG service
python ai/rag/retriever.py --config_path /root/ai/rag/config.ini --dcu_id 3 > rag.log 2>&1 &
Note: the first three services all need DCU devices, so they can run in the background inside a single container.
- Start the assistant service
python ai/server_start.py --config_path /root/ai/config.ini --log_path /var/log/assistant.log
Note: this service does not depend on DCU devices and can run on Kubernetes.
- Client call example:
```shell
python client.py --action query --query "你好, 我们公司想要购买几台测试机, 请问需要联系贵公司哪位?"
..
20xx-xx-xx hh:mm:ss.sss | DEBUG | __main__:<module>:30 -
reply: 您好,您需要联系中科曙光售前咨询平台,您可以通过访问官网,根据您所在地地址联系平台人员,
或者点击人工客服进行咨询,或者拨打中科曙光服务热线400-810-0466联系人工进行咨询。
如果您需要购买机架式服务器,I620--G30,请与平台人员联系。,
ref: ['train.json', 'FAQ_clean.xls', 'train.json'
```
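The same call can also be made directly over HTTP. The sketch below mirrors client.py's query() helper, which POSTs JSON to the /work endpoint and reads the reply and references fields:
```python
# Direct HTTP call to the assistant service, mirroring client.py's query() helper.
import requests

resp = requests.post('http://127.0.0.1:8000/work',
                     json={'query': '你好, 我们公司想要购买几台测试机, 请问需要联系贵公司哪位?',
                           'history': []},
                     timeout=300)
resp.raise_for_status()
result = resp.json()
print('reply:', result['reply'])
print('ref:', result['references'])
```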
# 🛠️ Roadmap
1. Fuse web search into the workflow
2. More advanced LLM fine-tuning and support (Llama 3)
3. vLLM acceleration
4. Support multiple local LLMs as well as remote LLM calls
5. Support multimodal image-and-text Q&A
# 🛠️ Known Issues
Fix for ```Could not load library with AVX2 support due to: ModuleNotFoundError("No module named 'faiss.swigfaiss_avx2'")```:
Locate the faiss installation:
```
import faiss
print(faiss.__file__)
# /.../python3.10/site-packages/faiss/__init__.py
```
Create a symlink:
```
# cd your_python_path/site-packages/faiss
cd /.../python3.10/site-packages/faiss/
ln -s swigfaiss.py swigfaiss_avx2.py
```
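After the symlink is in place, the warning should no longer appear; a quick check (a sketch):
```python
# Verify the workaround: importing faiss should no longer warn about swigfaiss_avx2.
import faiss
from faiss import swigfaiss_avx2  # now resolves via the symlink created above
print('faiss loaded:', faiss.__version__)
```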
from .llm_service import ErrorCode # noqa E401
from .llm_service import FeatureDataBase # noqa E401
from .llm_service import Worker # noqa E401
from .llm_service import DocumentName # noqa E401
from .llm_service import DocumentProcessor # noqa E401
from .llm_service import rag_retrieve # noqa E401
import json
import requests
import argparse
import re
'''
Usage examples:
Public knowledge-base query:  python client.py --action query --query 'question'
Private knowledge-base query: python client.py --action query --query 'question' --user_id 'user_id'
'''
base_url = 'http://127.0.0.1:8000/%s'
def query(query, user_id=None):
url = base_url % 'work'
try:
header = {'Content-Type': 'application/json'}
# Add history to data
data = {
'query': query,
'history': []
}
if user_id:
data['user_id'] = user_id
resp = requests.post(url,
headers=header,
data=json.dumps(data),
timeout=300)
if resp.status_code != 200:
raise Exception(str((resp.status_code, resp.reason)))
return resp.json()['reply'], resp.json()['references']
except Exception as e:
print(str(e))
        return '', []
def get_streaming_response(response: requests.Response):
for chunk in response.iter_lines(chunk_size=1024, decode_unicode=False,
delimiter=b"\0"):
if chunk:
pattern = re.compile(rb'data: "(\\u[0-9a-fA-F]{4})"')
matches = pattern.findall(chunk)
decoded_data = []
for match in matches:
hex_value = match[2:].decode('ascii')
char = chr(int(hex_value, 16))
decoded_data.append(char)
print(char, end="", flush=True)
def stream_query(query):
url = base_url % 'stream'
try:
headers={
"Content-Type": "text/event-stream",
"Cache-Control": "no-cache",
"Connection": "keep-alive"
}
# Add history to data
data = {
'query': query,
'history': []
}
resp = requests.get(url,
headers=headers,
data=json.dumps(data),
timeout=300,
verify=False,
stream=True)
get_streaming_response(resp)
except Exception as e:
print(str(e))
def parse_args():
parser = argparse.ArgumentParser(description='.')
    parser.add_argument('--action', default='query', help='')
    parser.add_argument('--query',
                        default='your query',
                        help='')
    parser.add_argument('--user_id', default='')
parser.add_argument('--stream', action='store_true')
return parser.parse_args()
if __name__ == '__main__':
args = parse_args()
if args.stream:
        stream_query(args.query)
else:
reply, ref = query(args.query, args.user_id)
print('reply: {} \nref: {} '.format(reply, ref))
[default]
work_dir = /path/to/your/ai/work_dir
bind_port = 8000
use_template = False
output_format = True
[feature_database]
reject_throttle=0.61
embedding_model_path=/path/to/your/text2vec-large-chinese
reranker_model_path=/path/to/your/bce-reranker-base_v1
[model]
llm_service_address = http://127.0.0.1:8001
local_service_address = http://127.0.0.1:8002
cls_model_path = /path/of/classification
llm_model = /path/to/your/Llama3-8B-Chinese-Chat/
local_model = /path/to/your/Finetune/
max_input_length = 1400
import os
import nltk
nltk.data.path.append('/home/zhangwq/project/whl/nltk/nltk_data-gh-pages/nltk_data')
os.environ["CUDA_VISIBLE_DEVICES"] = '1'
os.environ['NLTK_DATA'] = '/home/zhangwq/project/whl/nltk/nltk_data-gh-pages/nltk_data'
import base64
import argparse
import uuid
import re
import io
from PIL import Image
from IPython.display import HTML, display
from langchain_experimental.open_clip import OpenCLIPEmbeddings
from langchain.schema.runnable import RunnableLambda, RunnablePassthrough
from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain.schema.document import Document
from langchain.storage import InMemoryStore
from langchain_community.vectorstores import FAISS
from unstructured.partition.pdf import partition_pdf
from langchain_core.messages import HumanMessage
from langchain_community.chat_models import ChatVertexAI
from langchain.schema.output_parser import StrOutputParser
from langchain_core.runnables import RunnableLambda
from loguru import logger
def plt_img_base64(img_base64):
# Create an HTML img tag with the base64 string as the source
image_html = f'<img src="data:image/jpeg;base64,{img_base64}" />'
# Display the image by rendering the HTML
display(HTML(image_html))
def multi_modal_rag_chain(retriever):
"""
Multi-modal RAG chain
"""
# Multi-modal LLM
model = ChatVertexAI(
temperature=0, model_name="gemini-pro-vision", max_output_tokens=1024
)
# RAG pipeline
chain = (
{
"context": retriever | RunnableLambda(split_image_text_types),
"question": RunnablePassthrough(),
}
| RunnableLambda(img_prompt_func)
| model
| StrOutputParser()
)
return chain
def img_prompt_func(data_dict):
"""
Join the context into a single string
"""
formatted_texts = "\n".join(data_dict["context"]["texts"])
messages = []
# Adding the text for analysis
text_message = {
"type": "text",
"text": (
"You are an AI scientist tasking with providing factual answers.\n"
"You will be given a mixed of text, tables, and image(s) usually of charts or graphs.\n"
"Use this information to provide answers related to the user question. \n"
f"User-provided question: {data_dict['question']}\n\n"
"Text and / or tables:\n"
f"{formatted_texts}"
),
}
messages.append(text_message)
# Adding image(s) to the messages if present
if data_dict["context"]["images"]:
for image in data_dict["context"]["images"]:
image_message = {
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{image}"},
}
messages.append(image_message)
return [HumanMessage(content=messages)]
def split_image_text_types(docs):
"""
Split base64-encoded images and texts
"""
b64_images = []
texts = []
for doc in docs:
# Check if the document is of type Document and extract page_content if so
if isinstance(doc, Document):
doc = doc.page_content
if looks_like_base64(doc) and is_image_data(doc):
doc = resize_base64_image(doc, size=(1300, 600))
b64_images.append(doc)
else:
texts.append(doc)
if len(b64_images) > 0:
return {"images": b64_images[:1], "texts": []}
return {"images": b64_images, "texts": texts}
def resize_base64_image(base64_string, size=(128, 128)):
"""
Resize an image encoded as a Base64 string
"""
# Decode the Base64 string
img_data = base64.b64decode(base64_string)
img = Image.open(io.BytesIO(img_data))
# Resize the image
resized_img = img.resize(size, Image.LANCZOS)
# Save the resized image to a bytes buffer
buffered = io.BytesIO()
resized_img.save(buffered, format=img.format)
# Encode the resized image to Base64
return base64.b64encode(buffered.getvalue()).decode("utf-8")
def is_image_data(b64data):
"""
Check if the base64 data is an image by looking at the start of the data
"""
image_signatures = {
b"\xFF\xD8\xFF": "jpg",
b"\x89\x50\x4E\x47\x0D\x0A\x1A\x0A": "png",
b"\x47\x49\x46\x38": "gif",
b"\x52\x49\x46\x46": "webp",
}
try:
header = base64.b64decode(b64data)[:8] # Decode and get the first 8 bytes
for sig, format in image_signatures.items():
if header.startswith(sig):
return True
return False
except Exception:
return False
def looks_like_base64(sb):
"""Check if the string looks like base64"""
return re.match("^[A-Za-z0-9+/]+[=]{0,2}$", sb) is not None
def create_multi_vector_retriever(
vectorstore, text_summaries, texts, table_summaries, tables, image_summaries, images
):
"""
Create retriever that indexes summaries, but returns raw images or texts
"""
# Initialize the storage layer
store = InMemoryStore()
id_key = "doc_id"
# Create the multi-vector retriever
retriever = MultiVectorRetriever(
vectorstore=vectorstore,
docstore=store,
id_key=id_key,
)
# Helper function to add documents to the vectorstore and docstore
def add_documents(retriever, doc_summaries, doc_contents):
doc_ids = [str(uuid.uuid4()) for _ in doc_contents]
summary_docs = [
Document(page_content=s, metadata={id_key: doc_ids[i]})
for i, s in enumerate(doc_summaries)
]
retriever.vectorstore.add_documents(summary_docs)
retriever.docstore.mset(list(zip(doc_ids, doc_contents)))
# Add texts, tables, and images
# Check that text_summaries is not empty before adding
if text_summaries:
add_documents(retriever, text_summaries, texts)
# Check that table_summaries is not empty before adding
if table_summaries:
add_documents(retriever, table_summaries, tables)
# Check that image_summaries is not empty before adding
if image_summaries:
add_documents(retriever, image_summaries, images)
return retriever
def extract_elements_from_pdf(file_path: str, image_output_dir_path: str):
pdf_list = [os.path.join(file_path, file) for file in os.listdir(file_path) if file.endswith('.pdf')]
tables = []
texts = []
raw_pdf_elements = partition_pdf(
filename=pdf_list[0],
extract_images_in_pdf=True,
infer_table_structure=True,
chunking_strategy="by_title",
max_characters=4000,
new_after_n_chars=3800,
combine_text_under_n_chars=2000,
image_output_dir_path=image_output_dir_path,
)
for element in raw_pdf_elements:
if "unstructured.documents.elements.Table" in str(type(element)):
tables.append(str(element))
elif "unstructured.documents.elements.CompositeElement" in str(type(element)):
texts.append(str(element))
return texts, tables
class Summary:
def __init__(self):
pass
def encode_image(self, image_path):
"""Getting the base64 string"""
with open(image_path, "rb") as image_file:
return base64.b64encode(image_file.read()).decode("utf-8")
def image_summarize(self, img_base64, prompt):
"""Make image summary"""
model = ChatVertexAI(model_name="gemini-pro-vision", max_output_tokens=1024)
msg = model(
[
HumanMessage(
content=[
{"type": "text", "text": prompt},
{
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{img_base64}"},
},
]
)
]
)
return msg.content
def generate_img_summaries(self, path):
"""
Generate summaries and base64 encoded strings for images
path: Path to list of .jpg files extracted by Unstructured
"""
# Store base64 encoded images
img_base64_list = []
# Store image summaries
image_summaries = []
# Prompt
prompt = """You are an assistant tasked with summarizing images for retrieval. \
These summaries will be embedded and used to retrieve the raw image. \
Give a concise summary of the image that is well optimized for retrieval."""
# Apply to images
for img_file in sorted(os.listdir(path)):
if img_file.endswith(".jpg"):
img_path = os.path.join(path, img_file)
base64_image = self.encode_image(img_path)
img_base64_list.append(base64_image)
image_summaries.append(self.image_summarize(base64_image, prompt))
return img_base64_list, image_summaries
def generate_text_summaries(self, texts, tables):
text_summaries = texts
table_summaries = tables
return text_summaries, table_summaries
def parse_args():
"""Parse command-line arguments."""
parser = argparse.ArgumentParser()
parser.add_argument(
'--file_path',
type=str,
default='/home/zhangwq/data/art_test/pdf',
help='')
parser.add_argument(
'--image_output_dir_path',
default='/home/zhangwq/data/art_test',
help='')
parser.add_argument(
'--query',
default='compare and contrast between mistral and llama2 across benchmarks and explain the reasoning in detail',
help='')
args = parser.parse_args()
return args
if __name__ == '__main__':
args = parse_args()
summary = Summary()
texts, tables = extract_elements_from_pdf(file_path=args.file_path,
image_output_dir_path=args.image_output_dir_path)
text_summaries, table_summaries = summary.generate_text_summaries(texts, tables)
img_base64_list, image_summaries = summary.generate_img_summaries(args.file_path)
embeddings = OpenCLIPEmbeddings(
model_name="/home/zhangwq/model/CLIP_VIT", checkpoint="laion2b_s34b_b88k")
embeddings.client = embeddings.client.half()
    # FAISS needs an explicit index and docstore (collection_name is not a FAISS argument)
    import faiss as faiss_lib
    from langchain_community.docstore.in_memory import InMemoryDocstore
    vectorstore = FAISS(embedding_function=embeddings,
                        index=faiss_lib.IndexFlatL2(len(embeddings.embed_query('init'))),
                        docstore=InMemoryDocstore(),
                        index_to_docstore_id={})
# Create retriever
retriever = create_multi_vector_retriever(
vectorstore,
text_summaries,
texts,
table_summaries,
tables,
image_summaries,
img_base64_list,
)
chain_multimodal_rag = multi_modal_rag_chain(retriever)
docs = retriever.get_relevant_documents(args.query, limit=1)
logger.info(docs[0])
chain_multimodal_rag.invoke(args.query)
\ No newline at end of file
import os
import time
import torch
import argparse
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
from vllm import LLM, SamplingParams


def infer_hf_chatglm(model_path, prompt):
    '''Run chatglm2 inference with transformers'''
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModel.from_pretrained(model_path, trust_remote_code=True, device_map="auto").half().cuda()
model = model.eval()
start_time = time.time()
generated_text, _ = model.chat(tokenizer, prompt, history=[])
print("chat time ", time.time()- start_time)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
return generated_text
def infer_hf_llama3(model_path, prompt):
    '''Run llama3 inference with transformers'''
input_query = {"role": "user", "content": prompt}
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
model_path, torch_dtype="auto", device_map="auto")
input_ids = tokenizer.apply_chat_template(
[input_query,], add_generation_prompt=True, return_tensors="pt").to(model.device)
outputs = model.generate(
input_ids,
max_new_tokens=512,
do_sample=True,
temperature=1,
top_p=0.95,
)
response = outputs[0][input_ids.shape[-1]:]
generated_text = tokenizer.decode(response, skip_special_tokens=True)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
return generated_text
def infer_vllm_llama3(model_path, message, tp_size=1, max_model_len=1024):
    '''Run llama3 inference with vLLM'''
tokenizer = AutoTokenizer.from_pretrained(model_path)
messages = [{"role": "user", "content": message}]
print(f"Prompt: {messages!r}")
sampling_params = SamplingParams(temperature=1,
top_p=0.95,
max_tokens=1024,
stop_token_ids=[tokenizer.eos_token_id])
    llm = LLM(model=model_path,
              max_model_len=max_model_len,
              trust_remote_code=True,
              enforce_eager=True,
              dtype="float16",
              tensor_parallel_size=tp_size)
# generate answer
start_time = time.time()
prompt_token_ids = [tokenizer.apply_chat_template(messages, add_generation_prompt=True)]
outputs = llm.generate(prompt_token_ids=prompt_token_ids, sampling_params=sampling_params)
print("total infer time", time.time() - start_time)
# results
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Generated text: {generated_text!r}")
def infer_vllm_chatglm(model_path, message, tp_size=1):
    '''Run chatglm2 inference with vLLM'''
sampling_params = SamplingParams(temperature=1.0,
top_p=0.9,
max_tokens=1024)
llm = LLM(model=model_path,
trust_remote_code=True,
enforce_eager=True,
dtype="float16",
tensor_parallel_size=tp_size)
# generate answer
print(f"chatglm2 Prompt: {message!r}")
outputs = llm.generate(message, sampling_params=sampling_params)
# results
for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Generated text: {generated_text!r}")
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--model_path', default='')
    parser.add_argument('--query', default="DCU是什么?", help='Question to ask.')
parser.add_argument('--use_hf', action='store_true')
args = parser.parse_args()
return args
if __name__ == "__main__":
args = parse_args()
is_llama = True if "llama" in args.model_path else False
print("Is llama", is_llama)
if args.use_hf:
# transformers
if is_llama:
infer_hf_llama3(args.model_path, args.query)
else:
infer_hf_chatglm(args.model_path, args.query)
else:
# vllm
if is_llama:
infer_vllm_llama3(args.model_path, args.query)
else:
infer_vllm_chatglm(args.model_path, args.query)
from .feature_database import FeatureDataBase, DocumentProcessor, DocumentName # noqa E401
from .helper import TaskCode, ErrorCode, LogManager # noqa E401
from .http_client import OpenAPIClient, ClassifyClient # noqa E401
from .worker import Worker # noqa E401
...@@ -8,7 +8,7 @@ import hashlib
import textract
import shutil
import configparser
import json
from multiprocessing import Pool
from typing import List
from loguru import logger
...@@ -18,7 +18,17 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores.faiss import FAISS
from torch.cuda import empty_cache
from bs4 import BeautifulSoup
from elastic_keywords_search import ElasticKeywordsSearch
from retriever import Retriever


def check_envs(args):
    if all(isinstance(item, int) for item in args.DCU_ID):
        os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, args.DCU_ID))
        logger.info(f"Set environment variable CUDA_VISIBLE_DEVICES to {args.DCU_ID}")
    else:
        logger.error(f"The --DCU_ID argument must be a list of integers, but got {args.DCU_ID}")
        raise ValueError("The --DCU_ID argument must be a list of integers")
class DocumentName:
...@@ -143,14 +153,7 @@ class DocumentProcessor:
            text = re.sub(r'\n\s*\n', '\n\n', text)

        elif file_type == 'excel':
            text += self.read_excel(filepath)

        elif file_type == 'word' or file_type == 'ppt':
            # https://stackoverflow.com/questions/36001482/read-doc-file-with-python
...@@ -177,6 +180,17 @@ class DocumentProcessor:
        return text, None

    def read_excel(self, filepath: str):
        table = None
        if filepath.endswith('.csv'):
            table = pd.read_csv(filepath)
        else:
            table = pd.read_excel(filepath)
        if table is None:
            return ''
        json_text = table.dropna(axis=1).to_json(force_ascii=False)
        return json_text

    def read_pdf(self, filepath: str):
        # load pdf and serialize table
...@@ -231,7 +245,7 @@ class FeatureDataBase:

    def __init__(self,
                 embeddings: HuggingFaceEmbeddings,
                 reranker: BCERerank,
                 reject_throttle=-1) -> None:
        # logger.debug('loading text2vec model..')
        self.embeddings = embeddings
...@@ -242,7 +256,7 @@ class FeatureDataBase:
        self.reject_throttle = reject_throttle if reject_throttle else -1

        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1068, chunk_overlap=32)

    def get_documents(self, text, file):
        # if len(text) <= 1:
...@@ -256,12 +270,17 @@ class FeatureDataBase:
            documents.append(chunk)
        return documents

    def build_database(self, files: list, work_dir: str, file_opr: DocumentProcessor, elastic_search=None):
        feature_dir = os.path.join(work_dir, 'db_response')
        if not os.path.exists(feature_dir):
            os.makedirs(feature_dir)

        documents = []
        texts_for_es = []
        metadatas_for_es = []
        ids_for_es = []
        for i, file in enumerate(files):
            if not file.status:
                continue
...@@ -273,58 +292,31 @@ class FeatureDataBase:
                file.message = str(error)
                continue
            file.message = str(text[0])

            texts_for_es.append(text[0])
            metadatas_for_es.append({'source': file.basename, 'read': file.origin_path})
            ids_for_es.append(str(i))

            document = self.get_documents(text, file)
            documents += document
            logger.debug('Positive pipeline {}/{}.. register 《{}》 and split {} documents'
                         .format(i + 1, len(files), file.basename, len(document)))

        if elastic_search is not None:
            logger.debug('ES database pipeline register {} documents into database...'.format(len(texts_for_es)))
            es_time_before_register = time.time()
            elastic_search.add_texts(texts_for_es, metadatas=metadatas_for_es, ids=ids_for_es)
            es_time_after_register = time.time()
            logger.debug('ES database pipeline take time: {} '.format(es_time_after_register - es_time_before_register))

        logger.debug('Vector database pipeline register {} documents into database...'.format(len(documents)))
        ve_time_before_register = time.time()
        vs = FAISS.from_documents(documents, self.embeddings)
        vs.save_local(feature_dir)
        ve_time_after_register = time.time()
        logger.debug('Vector database pipeline take time: {} '.format(ve_time_after_register - ve_time_before_register))
    def preprocess(self, files: list, work_dir: str, file_opr: DocumentProcessor):
...@@ -343,7 +335,7 @@ class FeatureDataBase:
                file.status = False
                file.message = 'skip image'

            elif file._category in ['pdf', 'word', 'ppt', 'html', 'excel']:
                # read pdf/word/excel file and save to text format
                md5 = file_opr.md5(file.origin_path)
                file.copy_path = os.path.join(preproc_dir,
...@@ -363,7 +355,7 @@ class FeatureDataBase:
                    file.status = False
                    file.message = str(e)

            elif file._category in ['json']:
                file.status = True
                file.copy_path = file.origin_path
                file.message = 'preprocessed'
...@@ -385,11 +377,10 @@ class FeatureDataBase:
                    file.status = False
                    file.message = 'read error'

    def initialize(self, files: list, work_dir: str, file_opr: DocumentProcessor, elastic_search=None):
        self.preprocess(files=files, work_dir=work_dir, file_opr=file_opr)
        self.build_database(files=files, work_dir=work_dir, file_opr=file_opr, elastic_search=elastic_search)

    def merge_db_response(self, faiss: FAISS, files: list, work_dir: str, file_opr: DocumentProcessor):
...@@ -408,6 +399,7 @@ class FeatureDataBase:
                file.status = False
                file.message = str(error)
                continue
            logger.info(str(len(text)), text, str(text[0]))
            file.message = str(text[0])
            # file.message = str(len(text))
...@@ -416,9 +408,13 @@ class FeatureDataBase:

            documents += self.get_documents(text, file)

        if documents:
            vs = FAISS.from_documents(documents, self.embeddings)
            if faiss:
                faiss.merge_from(vs)
                faiss.save_local(feature_dir)
            else:
                vs.save_local(feature_dir)
def test_reject(retriever: Retriever):
...@@ -461,17 +457,21 @@ def parse_args():
        description='Feature store for processing directories.')
    parser.add_argument('--work_dir',
                        type=str,
                        default='',
                        help='Custom.')
    parser.add_argument(
        '--repo_dir',
        type=str,
        default='',
        help='Directory of files to read.')
    parser.add_argument(
        '--config_path',
        default='./ai/rag/config.ini',
        help='Path to the config file.')
    parser.add_argument(
        '--DCU_ID',
        default=[7],
        help='Set the DCU device(s).')
    args = parser.parse_args()
    return args
...@@ -482,53 +482,45 @@ if __name__ == '__main__':
    log_file_path = os.path.join(args.work_dir, 'application.log')
    logger.add(log_file_path, rotation='10MB', compression='zip')
    check_envs(args)

    config = configparser.ConfigParser()
    config.read(args.config_path)

    # only init vector retriever
    retriever = Retriever(config)
    fs_init = FeatureDataBase(embeddings=retriever.embeddings,
                              reranker=retriever.reranker)

    # init es retriever; drop_old means build a new index instead of updating the existing 'index_name'
    es_url = config.get('rag', 'es_url')
    index_name = config.get('rag', 'index_name')
    elastic_search = ElasticKeywordsSearch(
        elasticsearch_url=es_url,
        index_name=index_name,
        drop_old=True)

    # walk all files in repo dir
    file_opr = DocumentProcessor()
    files = file_opr.scan_directory(repo_dir=args.repo_dir)
    fs_init.initialize(files=files, work_dir=args.work_dir, file_opr=file_opr, elastic_search=elastic_search)
    file_opr.summarize(files)
    # del fs_init

    # with open(os.path.join(args.work_dir, 'sample', 'positive.json')) as f:
    #     positive_sample = json.load(f)
    # with open(os.path.join(args.work_dir, 'sample', 'negative.json')) as f:
    #     negative_sample = json.load(f)
    #
    # with open(os.path.join(args.work_dir, 'sample', 'positive.txt'), 'r', encoding='utf-8') as file:
    #     positive_sample = []
    #     for line in file:
    #         positive_sample.append(line.strip())
    #
    # with open(os.path.join(args.work_dir, 'sample', 'negative.txt'), 'r', encoding='utf-8') as file:
    #     negative_sample = []
    #     for line in file:
    #         negative_sample.append(line.strip())
    #
    # test_reject(retriever)
import time
import random
import argparse
from selenium import webdriver
from selenium.webdriver.common.keys import Keys
from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC
from selenium.webdriver.firefox.options import Options
from concurrent.futures import ThreadPoolExecutor, as_completed
from selenium.common.exceptions import TimeoutException, WebDriverException, NoSuchElementException
from bs4 import BeautifulSoup
from loguru import logger
class FetchBingResults:
def __init__(self,query):
self.query = query
def get_driver(self):
options = Options()
        options.set_preference('permissions.default.image', 2)  # 2 disables image loading
        options.add_argument('--headless')  # headless mode, no browser window
        options.add_argument('--no-sandbox')  # disable the sandbox
        options.add_argument('--disable-gpu')  # disable GPU hardware acceleration
        options.add_argument('--disable-dev-shm-usage')  # do not use /dev/shm shared memory
        options.add_argument('--user-agent="Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"')  # set the user agent
        options.set_preference("intl.accept_languages", "en-US,en")  # set the language
driver = webdriver.Firefox(options=options)
return driver
def fetch_bing_results(self, num = 1):
retries = 3
for _ in range(retries):
driver = self.get_driver()
driver.get(f'https://www.bing.com/search?q={self.query}')
try:
WebDriverWait(driver, 5).until(
EC.presence_of_element_located((By.XPATH, '//li[@class="b_algo"]'))
)
html_content = driver.page_source
soup = BeautifulSoup(html_content, 'html.parser')
b_results = soup.find('ol', {'id': 'b_results'})
results = list(b_results.find_all('li', class_='b_algo'))
search_results = []
                # create a thread pool
                with ThreadPoolExecutor(max_workers = num) as executor:
                    # map each submitted fetch task to its search result
future_to_result = {
executor.submit(self.fetch_article_content, result.find('a')['href']): result
for result in results[:num]
}
for future in as_completed(future_to_result):
result = future_to_result[future]
try:
content, current_url = future.result()
title = result.find('h2').text
return content[:1000]
# search_results.append({'title': title, 'content': content, 'link': current_url})
except Exception as exc:
logger.error(f'Generated an exception: {exc}')
# return search_results
except (TimeoutException, WebDriverException) as e:
logger.error(f"Attempt {_ + 1} failed: {str(e)}")
                time.sleep(random.uniform(1, 3))  # wait a random interval before retrying
finally:
driver.quit()
logger.error("All retries failed.")
return ''
def fetch_article_content(self,link):
retries = 3
for _ in range(retries):
driver = self.get_driver()
driver.get(link)
try:
try:
WebDriverWait(driver, 10).until(
EC.presence_of_element_located((By.LINK_TEXT, 'Please click here if the page does not redirect automatically ...'))
).click()
except TimeoutException:
logger.error("No redirection link found.")
WebDriverWait(driver, 10).until(
EC.presence_of_element_located((By.TAG_NAME, 'body'))
)
                # run JavaScript to make sure all dynamic content has loaded
                driver.execute_script("window.scrollTo(0, document.body.scrollHeight);")
                # wait until the page reports it is fully loaded
WebDriverWait(driver, 10).until(
lambda d: d.execute_script('return document.readyState') == 'complete'
)
article_page_source = driver.page_source
article_soup = BeautifulSoup(article_page_source, 'html.parser')
                # extract the page text
                content = article_soup.get_text(strip=True)
                # get the current page URL
current_url = driver.current_url
return content,current_url
except (TimeoutException, WebDriverException, NoSuchElementException) as e:
logger.error(f"Attempt {_ + 1} failed: {str(e)}")
                time.sleep(random.uniform(1, 3))  # wait a random interval before retrying
finally:
driver.quit()
return None, link
    def __del__(self):
        # drivers are created and quit per request; only quit one if it is still around
        driver = getattr(self, 'driver', None)
        if driver is not None:
            driver.quit()
def parse_args():
parser = argparse.ArgumentParser(
description='')
parser.add_argument(
'--query',
        default='介绍下曙光DCU',
        help='Question to ask.')
args = parser.parse_args()
return args
def main():
args = parse_args()
    fetch_bing = FetchBingResults(args.query)
results = fetch_bing.fetch_bing_results()
print(results)
if __name__ == "__main__":
main()
\ No newline at end of file
...@@ -30,7 +30,7 @@ class ErrorCode(Enum):
    SEARCH_FAIL = 15, 'Search fail, please check TOKEN and quota'
    NOT_FIND_RELATED_DOCS = 16, 'No relevant documents found, the following answer is generated directly by LLM.'
    NON_COMPLIANCE_QUESTION = 17, 'Non-compliance question, refusing to answer.'
    NO_WEB_SEARCH_RESULT = 18, 'Can not fetch result from web.'

    def __new__(cls, value, description):
        """Create new instance of ErrorCode."""
......
import os
import time
import json
import httpx
import configparser
import torch
import numpy as np
from loguru import logger
from BCEmbedding.tools.langchain import BCERerank
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.retrievers import ContextualCompressionRetriever
from langchain_community.vectorstores import FAISS
from langchain_community.vectorstores.utils import DistanceStrategy
from sklearn.metrics import precision_recall_curve
from transformers import BertForSequenceClassification, BertTokenizer
def build_history_messages(prompt, history, system: str = None):
history_messages = []
if system is not None and len(system) > 0:
history_messages.append({'role': 'system', 'content': system})
for item in history:
history_messages.append({'role': 'user', 'content': item[0]})
history_messages.append({'role': 'assistant', 'content': item[1]})
history_messages.append({'role': 'user', 'content': prompt})
return history_messages
class OpenAPIClient:
def __init__(self, url: str, model_name):
self.url = '%s/v1/chat/completions' % url
self.model_name = model_name
async def get_streaming_response(self, headers, data):
async with httpx.AsyncClient() as client:
async with client.stream("POST", self.url, json=data, headers=headers, timeout=300) as response:
async for line in response.aiter_lines():
if not line or 'DONE' in line:
continue
try:
result = json.loads(line.split('data:')[1])
output = result['choices'][0]['delta'].get('content')
except Exception as e:
logger.error('Model response parse failed:', e)
raise Exception('Model response parse failed.')
if not output:
continue
yield output
async def get_response(self, headers, data):
async with httpx.AsyncClient() as client:
resp = await client.post(self.url, json=data, headers=headers, timeout=300)
try:
result = json.loads(resp.content.decode("utf-8"))
output = result['choices'][0]['message']['content']
except Exception as e:
logger.error('Model response parse failed:', e)
raise Exception('Model response parse failed.')
return output
async def chat(self, prompt: str, history=[], stream=False):
header = {'Content-Type': 'application/json'}
# Add history to data
data = {
"model": self.model_name,
"messages": build_history_messages(prompt, history),
"stream": stream
}
logger.info("Request openapi param: {}".format(data))
if stream:
return self.get_streaming_response(header, data)
else:
return await self.get_response(header, data)
class ClassifyModel:
    def __init__(self, model_path, dcu_id):
        logger.info("Starting initial bert class model")
        # select the visible DCU before the model is moved onto the device
        os.environ["CUDA_VISIBLE_DEVICES"] = dcu_id
        logger.info(f"Set environment variable CUDA_VISIBLE_DEVICES to {dcu_id}")
        self.cls_model = BertForSequenceClassification.from_pretrained(model_path).float().cuda()
        self.cls_model.load_state_dict(torch.load(os.path.join(model_path, 'bert_cls_model.pth')))
        self.cls_model.eval()
        self.cls_tokenizer = BertTokenizer.from_pretrained(model_path)
def classfication(self, sentence):
inputs = self.cls_tokenizer(
sentence,
max_length=512,
truncation="longest_first",
return_tensors="pt")
inputs = inputs.to('cuda')
with torch.no_grad():
outputs = self.cls_model(**inputs)
logits = outputs[0]
score = torch.max(logits.data, 1)[1].tolist()
logger.info("分类结果: {}, {}".format(score[0], sentence))
return float(score[0])
class CacheRetriever:
def __init__(self, embedding_model_path: str, reranker_model_path: str, max_len: int = 4):
self.cache = dict()
self.max_len = max_len
# load text2vec and rerank model
        logger.info('loading text2vec and rerank models')
self.embeddings = HuggingFaceEmbeddings(
model_name=embedding_model_path,
model_kwargs={'device': 'cuda'},
encode_kwargs={
'batch_size': 1,
'normalize_embeddings': True
})
# half
self.embeddings.client = self.embeddings.client.half()
reranker_args = {
'model': reranker_model_path,
'top_n': 3,
'device': 'cuda',
'use_fp16': True
}
self.reranker = BCERerank(**reranker_args)
def get(self,
reject_throttle: float,
fs_id: str = 'default',
work_dir='workdir'
):
if fs_id in self.cache:
self.cache[fs_id]['time'] = time.time()
return self.cache[fs_id]['retriever']
if len(self.cache) >= self.max_len:
# drop the oldest one
del_key = None
min_time = time.time()
for key, value in self.cache.items():
cur_time = value['time']
if cur_time < min_time:
min_time = cur_time
del_key = key
if del_key is not None:
del_value = self.cache[del_key]
self.cache.pop(del_key)
del del_value['retriever']
retriever = Retriever(embeddings=self.embeddings,
reranker=self.reranker,
work_dir=work_dir,
reject_throttle=reject_throttle)
self.cache[fs_id] = {'retriever': retriever, 'time': time.time()}
return retriever
def pop(self, fs_id: str):
if fs_id not in self.cache:
return
del_value = self.cache[fs_id]
self.cache.pop(fs_id)
# manually free memory
del del_value
class Retriever:
def __init__(self, embeddings, reranker, work_dir: str, reject_throttle: float) -> None:
self.reject_throttle = reject_throttle
self.rejecter = None
self.retriever = None
self.compression_retriever = None
self.embeddings = embeddings
self.reranker = reranker
self.rejecter = FAISS.load_local(
os.path.join(work_dir, 'db_response'),
embeddings=embeddings,
allow_dangerous_deserialization=True)
self.vector_store = FAISS.load_local(
os.path.join(work_dir, 'db_response'),
embeddings=embeddings,
allow_dangerous_deserialization=True,
distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT)
self.retriever = self.vector_store.as_retriever(
search_type='similarity',
search_kwargs={
'score_threshold': self.reject_throttle,
'k': 30
}
)
self.compression_retriever = ContextualCompressionRetriever(
base_compressor=reranker, base_retriever=self.retriever)
if self.rejecter is None:
logger.warning('rejecter is None')
if self.retriever is None:
logger.warning('retriever is None')
def is_relative(self, sample, k=30, disable_throttle=False):
"""If no search results below the threshold can be found from the
database, reject this query."""
if self.rejecter is None:
return False, []
if disable_throttle:
# for searching throttle during update sample
docs_with_score = self.rejecter.similarity_search_with_relevance_scores(
sample, k=1)
if len(docs_with_score) < 1:
return False, docs_with_score
return True, docs_with_score
else:
# for retrieve result
# if no chunk passed the throttle, give the max
docs_with_score = self.rejecter.similarity_search_with_relevance_scores(
sample, k=k)
ret = []
max_score = -1
top1 = None
for (doc, score) in docs_with_score:
if score >= self.reject_throttle:
ret.append(doc)
if score > max_score:
max_score = score
top1 = (doc, score)
relative = True if len(ret) > 0 else False
return relative, [top1]
def update_throttle(self,
work_dir: str,
config_path: str = 'config.ini',
positive_sample=[],
negative_sample=[]):
import matplotlib.pyplot as plt
"""Update reject throttle based on positive and negative examples."""
if len(positive_sample) == 0 or len(negative_sample) == 0:
            raise Exception('positive and negative samples cannot be empty.')
all_samples = positive_sample + negative_sample
predictions = []
for sample in all_samples:
self.reject_throttle = -1
_, docs = self.is_relative(sample=sample,
disable_throttle=True)
score = docs[0][1]
predictions.append(max(0, score))
labels = [1 for _ in range(len(positive_sample))
] + [0 for _ in range(len(negative_sample))]
precision, recall, thresholds = precision_recall_curve(
labels, predictions)
plt.figure(figsize=(10, 8))
plt.plot(recall, precision, label='Precision-Recall curve')
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.title('Precision-Recall Curve')
plt.legend(loc='best')
plt.grid(True)
plt.savefig(os.path.join(work_dir, 'precision_recall_curve.png'), format='png')
plt.close()
logger.debug("Figure have been saved!")
thresholds = np.append(thresholds, 1)
max_precision = np.max(precision)
indices_with_max_precision = np.where(precision == max_precision)
optimal_recall = recall[indices_with_max_precision[0][0]]
optimal_threshold = thresholds[indices_with_max_precision[0][0]]
logger.debug(f"Optimal threshold with the highest recall under the highest precision is: {optimal_threshold}")
logger.debug(f"The corresponding precision is: {max_precision}")
logger.debug(f"The corresponding recall is: {optimal_recall}")
config = configparser.ConfigParser()
config.read(config_path)
config['feature_database']['reject_throttle'] = str(optimal_threshold)
with open(config_path, 'w') as configfile:
config.write(configfile)
logger.info(
f'Update optimal threshold: {optimal_threshold} to {config_path}' # noqa E501
)
return optimal_threshold
def query(self,
question: str,
):
time_1 = time.time()
if question is None or len(question) < 1:
return None, None, []
if len(question) > 512:
logger.warning('input too long, truncate to 512')
question = question[0:512]
chunks = []
references = []
relative, docs = self.is_relative(sample=question)
if relative:
docs = self.compression_retriever.get_relevant_documents(question)
for doc in docs:
doc = [doc.page_content]
chunks.append(doc)
# chunks = [doc.page_content for doc in docs]
references = [doc.metadata['source'] for doc in docs]
time_2 = time.time()
logger.debug('query:{} \nchunks:{} \nreferences:{} \ntimecost:{}'
.format(question, chunks, references, time_2 - time_1))
return chunks, [os.path.basename(r) for r in references]
else:
if len(docs) > 0:
references.append(docs[0][0].metadata['source'])
logger.info('feature database rejected!')
return chunks, references
import time
import os
import configparser
import argparse
import json
import asyncio
import uuid
from loguru import logger
from aiohttp import web
# from multiprocessing import Value
from transformers import AutoModelForCausalLM, AutoTokenizer
COMMON = {
"<光合组织登记网址>": "https://www.hieco.com.cn/partner?from=timeline",
"<官网>": "https://www.sugon.com/after_sale/policy?sh=1",
"<平台联系方式>": "1、访问官网,根据您所在地地址联系平台人员,网址地址:https://www.sugon.com/about/contact;\n2、点击人工客服进行咨询;\n3、请您拨打中科曙光服务热线400-810-0466联系人工进行咨询。",
"<购买与维修的咨询方法>": "1、确定付费处理,可以微信搜索'sugon中科曙光服务'小程序,选择'在线报修'业务\n2、先了解价格,可以微信搜索'sugon中科曙光服务'小程序,选择'其他咨询'业务\n3、请您拨打中科曙光服务热线400-810-0466",
"<服务器续保流程>": "1、微信搜索'sugon中科曙光服务'小程序,选择'延保与登记'业务\n2、点击人工客服进行登记\n3、请您拨打中科曙光服务热线400-810-0466根据语音提示选择维保与购买",
"<XC内外网OS网盘链接>": "【腾讯文档】XC内外网OS网盘链接:https://docs.qq.com/sheet/DTWtXbU1BZHJvWkJm",
"<W360-G30机器,安装Win7使用的镜像链接>": "W360-G30机器,安装Win7使用的镜像链接:https://pan.baidu.com/s/1SjHqCP6kJ9KzdJEBZDEynw;提取码:x6m4",
"<麒麟系统搜狗输入法下载链接>": "软件下载链接(百度云盘):链接:https://pan.baidu.com/s/18Iluvs4BOAfFET0yFMBeLQ,提取码:bhkf",
"<X660 G45 GPU服务器拆解视频网盘链接>": "链接: https://pan.baidu.com/s/1RkRGh4XY1T2oYftGnjLp4w;提取码: v2qi",
"<DS800,SANTRICITY存储IBM版本模拟器网盘链接>": "链接:https://pan.baidu.com/s/1euG9HGbPfrVbThEB8BX76g;提取码:o2ya",
"<E80-D312(X680-G55)风冷整机组装说明下载链接>": "链接:https://pan.baidu.com/s/17KDpm-Z9lp01WGp9sQaQ4w;提取码:0802",
"<X680 G55 风冷相关资料下载链接>": "链接:https://pan.baidu.com/s/1KQ-hxUIbTWNkc0xzrEQLjg;提取码:0802",
"<R620 G51刷写EEPROM下载>": "下载链接如下:http://10.2.68.104/tools/bytedance/eeprom/",
"<X7450A0服务器售后培训文件网盘链接>": "网盘下载:https://pan.baidu.com/s/1tZJIf_IeQLOWsvuOawhslQ?pwd=kgf1;提取码:kgf1",
"<福昕阅读器补丁链接>": "补丁链接: https://pan.baidu.com/s/1QJQ1kHRplhhFly-vxJquFQ,提取码: aupx1",
"<W330-H35A_22DB4/W3335HA安装win7网盘链接>": "硬盘链接: https://pan.baidu.com/s/1fDdGPH15mXiw0J-fMmLt6Q提取码: k97i",
"<X680 G55服务器售后培训资料网盘链接>": "云盘连接下载:链接:https://pan.baidu.com/s/1gaok13DvNddtkmk6Q-qLYg?pwd=xyhb提取码:xyhb",
"<展厅管理员>": "北京-穆淑娟18001053012\n天津-马书跃15720934870\n昆山-关天琪15304169908\n成都-贾小芳18613216313\n重庆-李子艺17347743273\n安阳-郭永军15824623085\n桐乡-李梦瑶18086537055\n青岛-陶祉伊15318733259",
"<线上预约展厅>": "北京、天津、昆山、成都、重庆、安阳、桐乡、青岛",
"<马华>": "联系人:马华,电话:13761751980,邮箱:china@pinbang.com",
"<梁静>": "联系人:梁静,电话:18917566297,邮箱:ing.liang@omaten.com",
"<徐斌>": "联系人:徐斌,电话:13671166044,邮箱:244898943@qq.com",
"<俞晓枫>": "联系人:俞晓枫,电话13750869272,邮箱:857233013@qq.com",
"<刘广鹏>": "联系人:刘广鹏,电话13321992411,邮箱:liuguangpeng@pinbang.com",
"<马英伟>": "联系人:马英伟,电话:13260021849,邮箱:13260021849@163.com",
"<杨洋>": "联系人:杨洋,电话15801203938,邮箱bing523888@163.com",
"<展会合规要求>": "1.展品内容:展品内容需符合公司合规要求,展示内容需经过法务合规审查。\n2.文字材料内容:文字材料内容需符合公司合规要求,展示内容需经过法务合规审查。\n3.展品标签:展品标签内容需符合公司合规要求。\n4.礼品内容:礼品内容需符合公司合规要求。\n5.视频内容:视频内容需符合公司合规要求,展示内容需经过法务合规审查。\n6.讲解词内容:讲解词内容需符合公司合规要求,展示内容需经过法务合规审查。\n7.现场发放材料:现场发放的材料内容需符合公司合规要求。\n8.展示内容:整体展示内容需要经过法务合规审查。",
"<展会质量>": "1.了解展会的组织者背景、往届展会的评价以及提供的服务支持,确保展会的专业性和高效性。\n.了解展会的规模、参观人数、行业影响力等因素,以判断展会是否能够提供足够的曝光度和商机。\n3.关注同行业其他竞争对手是否参展,以及他们的展位布置、展示内容等信息,以便制定自己的参展策略。\n4.展会的日期是否与公司的其他重要活动冲突,以及举办地点是否便于客户和合作伙伴的参观。\n5.销售部门会询问展会方提供的宣传渠道和推广服务,以及如何利用这些资源来提升公司及产品的知名度。\n6.记录展会期间的重要领导参观、商机线索、合作洽谈、公司拜访预约等信息,跟进后续商业机会。",
"<摊位费规则>": "根据展位面积大小,支付相应费用。\n展位照明费:支付展位内的照明服务费。\n展位保安费:支付展位内的保安服务费。\n展位网络使用费:支付展位内网络使用的费用。\n展位电源使用费:支付展位内电源使用的费用。",
"<展会主题要求>": "展会主题的确定需要符合公司产品和服务业务范围,以确保能够吸引目标客户群体。因此,确定展会主题时,需要考虑以下因素:\n专业性:展会的主题应确保专业性,符合行业特点和目标客户的需求。\n目标客户群体:展会的主题定位应考虑目标客户群体,确保能够吸引他们的兴趣。\n业务重点:展会的主题应突出公司的业务重点和优势,以便更好地推广公司的核心产品或服务。\n行业影响力:展会的主题定位需要考虑行业的最新发展趋势,以凸显公司的行业地位和影响力。\n往届展会经验:可以参考往届展会的主题定位,总结经验教训,以确定本届展会的主题。\n市场部意见:在确定展会主题时,应听取市场部的意见,确保主题符合公司的整体市场战略。\n领导意见:还需要考虑公司领导的意见,以确保展会主题符合公司的战略发展方向。",
"<办理展商证注意事项>": "人员范围:除公司领导和同事需要办理展商证外,展会运营工作人员也需要办理。\n提前准备:展商证的办理需要提前进行,以确保摄影师、摄像师等工作人员可以提前入场进行布置。\n办理流程:需要熟悉展商证的办理流程,准备好相关材料,如身份证件等。\n数量需求:需要评估所需的展商证数量,避免数量不足或过多的情况。\n有效期限:展商证的有效期限需要注意,避免在展期内过期。\n存放安全:办理完的展商证需要妥善保管,避免丢失或被他人使用。\n使用规范:使用展商证时需要遵守展会相关规定,不得转让给他人使用。\n回收处理:展会结束后,需要及时回收展商证,避免泄露相关信息。",
"<项目单价要求>": "请注意:无论是否年框供应商,项目单价都不得超过采购部制定的“2024常见活动项目标准单价”,此报价仅可内部使用,严禁外传",
"<年框供应商细节表格>": "在线表格https://kdocs.cn/l/camwZE63frNw",
"<年框供应商流程>": "1.需求方发出项目需求(大型项目需比稿)\n2.外协根据项目需求报价,提供需求方“预算单”(按照基准单价报价,如有发现不按单价情况,解除合同不再使用)\n3.需求方确认预算价格,并提交OA市场活动申请\n4.外协现场执行\n5.需求方现场验收,并签署验收单(物料、设备、人员等实际清单)\n6.外协出具结算单(金额与验收单一致,加盖公章)、结案报告、年框合同,作为报销凭证\n7.外协请需求方项目负责人填写“满意度调研表”(如无,会影响年度评价)\n8.需求方项目经理提交报销",
"<市场活动结案报告内容>": "1.项目简介(时间、地点、参与人数等);2.最终会议安排;3.活动各环节现场图片;4.费用相关证明材料(如执行人员、物料照片);5.活动成效汇总;6.活动原始照片/视频网络链接",
"<展板设计选择>": "1.去OA文档中心查找一些设计模板; 2. 联系专业的活动服务公司来协助设计",
"<餐费标准>": "一般地区的餐饮费用规定为不超过300元/人(一顿正餐),特殊地区则为不超过400元/人(一顿正餐),特殊地区的具体规定请参照公司的《差旅费管理制度》",
"":"",
}
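# COMMON maps placeholder tags that the LLM may emit (e.g. "<官网>") to canned
# answer text; substitution() below expands these tags in the final reply.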
def build_history_messages(prompt, history, system: str = None):
history_messages = []
if system is not None and len(system) > 0:
history_messages.append({'role': 'system', 'content': system})
for item in history:
history_messages.append({'role': 'user', 'content': item[0]})
history_messages.append({'role': 'assistant', 'content': item[1]})
history_messages.append({'role': 'user', 'content': prompt})
return history_messages
def substitution(output_text):
    # Expand placeholder tags (e.g. "<官网>") into the canned text stored in COMMON.
    import re
    if isinstance(output_text, list):
        output_text = output_text[0]
    # Pass the regex options via flags=: the third positional argument of re.split
    # is maxsplit, not flags.
    matchObj = re.split('.*(<.*>).*', output_text, flags=re.M | re.I)
    if len(matchObj) > 1:
        obj = matchObj[1]
        replace_str = COMMON.get(obj)
        if replace_str:
            output_text = output_text.replace(obj, replace_str)
            logger.info(f"{obj} replaced with {replace_str}, after: {output_text}")
    return output_text
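# Illustrative example (not executed): with the "<官网>" entry from COMMON above,
#     substitution('相关政策请参考<官网>')
# returns
#     '相关政策请参考https://www.sugon.com/after_sale/policy?sh=1'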
class LLMInference:
def __init__(self,
model,
tokenizer,
device: str = 'cuda',
) -> None:
self.device = device
self.model = model
self.tokenizer = tokenizer
def generate_response(self, prompt, history=[]):
print("generate")
output_text = ''
error = ''
time_tokenzier = time.time()
try:
output_text = self.chat(prompt, history)
except Exception as e:
error = str(e)
logger.error(error)
time_finish = time.time()
logger.debug('output_text:{} \ntimecost {} '.format(output_text,
time_finish - time_tokenzier))
return output_text, error
def chat(self, messages, history=[]):
'''单轮问答'''
logger.info("****************** in chat ******************")
try:
# transformers
input_ids = self.tokenizer.apply_chat_template(
messages, add_generation_prompt=True, return_tensors="pt").to('cuda')
outputs = self.model.generate(
input_ids,
max_new_tokens=1024,
)
response = outputs[0][input_ids.shape[-1]:]
generated_text = self.tokenizer.decode(response, skip_special_tokens=True)
output_text = substitution(generated_text)
logger.info(f"using transformers, output_text {output_text}")
return output_text
except Exception as e:
logger.error(f"chat inference failed, {e}")
def chat_stream(self, messages, history=[]):
'''流式服务'''
# HuggingFace
logger.info("****************** in chat stream *****************")
current_length = 0
logger.info(f"stream_chat messages {messages}")
for response, _, _ in self.model.stream_chat(self.tokenizer, messages, history=history,
max_length=1024,
past_key_values=None,
return_past_key_values=True):
output_text = response[current_length:]
output_text = substitution(output_text)
logger.info(f"using transformers chat_stream, Prompt: {messages!r}, Generated text: {output_text!r}")
yield output_text
current_length = len(response)
def init_model(model_path, use_vllm=False, tensor_parallel_size=1):
## init models
logger.info("Starting initial model of LLM")
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
if use_vllm:
from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams
sampling_params = SamplingParams(temperature=1,
top_p=0.95,
max_tokens=1024,
early_stopping=False,
stop_token_ids=[tokenizer.eos_token_id]
)
# vLLM基础配置
args = AsyncEngineArgs(model_path)
args.worker_use_ray = False
args.engine_use_ray = False
args.tokenizer = model_path
args.tensor_parallel_size = tensor_parallel_size
args.trust_remote_code = True
args.enforce_eager = True
args.max_model_len = 1024
args.dtype = 'float16'
# 加载模型
engine = AsyncLLMEngine.from_engine_args(args)
return engine, tokenizer, sampling_params
else:
# huggingface
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).half().cuda().eval()
return model, tokenizer, None
def hf_inference(bind_port, model, tokenizer, stream_chat):
'''启动 hf Web 服务器,接收 HTTP 请求,并通过调用本地的 LLM 推理服务生成响应. '''
llm_infer = LLMInference(model, tokenizer)
async def inference(request):
start = time.time()
input_json = await request.json()
prompt = input_json['query']
history = input_json['history']
messages = [{"role": "user", "content": prompt}]
logger.info("****************** use transformers ******************")
        if stream_chat:
            # chat_stream is a generator yielding incremental chunks; collect and join
            # them so the JSON response below carries the complete text.
            chunks = await asyncio.to_thread(
                lambda: list(llm_infer.chat_stream(messages=messages, history=history)))
            text = ''.join(chunks)
        else:
            text = await asyncio.to_thread(llm_infer.chat, messages=messages, history=history)
logger.debug('问题:{} 回答:{} \ntimecost {} '.format(prompt, text, time.time() - start))
return web.json_response({'text': text})
app = web.Application()
app.add_routes([web.post('/hf_inference', inference)])
web.run_app(app, host='0.0.0.0', port=bind_port)
def vllm_inference(bind_port, model, tokenizer, sampling_params):
'''启动 Web 服务器,接收 HTTP 请求,并通过调用本地的 LLM 推理服务生成响应. '''
async def inference(request):
start = time.time()
input_json = await request.json()
prompt = input_json['query']
# history = input_json['history']
messages = [{"role": "user", "content": prompt}]
logger.info("****************** use vllm ******************")
## generate template
input_text = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True)
logger.info(f"The input_text is {input_text}")
assert model is not None
request_id = str(uuid.uuid4().hex)
results_generator = model.generate(input_text, sampling_params=sampling_params, request_id=request_id)
# Non-streaming case
logger.info("****************** in chat ******************")
final_output = None
async for request_output in results_generator:
# if await request.is_disconnected():
# # Abort the request if the client disconnects.
# await engine.abort(request_id)
# return Response(status_code=499)
final_output = request_output
assert final_output is not None
text = [output.text for output in final_output.outputs]
output_text = substitution(text)
logger.debug('问题:{} 回答:{} \ntimecost {} '.format(prompt, output_text, time.time() - start))
return web.json_response({'text': output_text})
app = web.Application()
app.add_routes([web.post('/vllm_inference', inference)])
web.run_app(app, host='0.0.0.0', port=bind_port)
def vllm_inference_stream(bind_port, model, tokenizer, sampling_params):
'''启动 Web 服务器,接收 HTTP 请求,并通过调用本地的 LLM 推理服务生成响应. '''
async def inference(request):
input_json = await request.json()
prompt = input_json['query']
# history = input_json['history']
messages = [{"role": "user", "content": prompt}]
logger.info("****************** use vllm ******************")
## generate template
input_text = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True)
logger.info(f"The input_text is {input_text}")
assert model is not None
request_id = str(uuid.uuid4().hex)
results_generator = model.generate(input_text, sampling_params=sampling_params, request_id=request_id)
# Streaming case
logger.info("****************** in stream chat ******************")
response = web.StreamResponse()
await response.prepare(request)
async for request_output in results_generator:
text_outputs = [output.text for output in request_output.outputs]
await response.write((json.dumps({"text": text_outputs})+"\0").encode("utf-8"))
        await response.write_eof()
return response
app = web.Application()
app.add_routes([web.post('/vllm_inference_stream', inference)])
web.run_app(app, host='0.0.0.0', port=bind_port)
def set_envs(dcu_ids):
try:
os.environ["CUDA_VISIBLE_DEVICES"] = dcu_ids
logger.info(f"Set environment variable CUDA_VISIBLE_DEVICES to {dcu_ids}")
except Exception as e:
logger.error(f"{e}, but got {dcu_ids}")
raise ValueError(f"{e}")
def parse_args():
'''参数'''
parser = argparse.ArgumentParser(
description='Feature store for processing directories.')
parser.add_argument(
'--config_path',
default='../config.ini',
help='config目录')
parser.add_argument(
'--query',
default='写一首诗',
help='提问的问题.')
parser.add_argument(
'--DCU_ID',
type=str,
default='6',
help='设置DCU卡号,卡号之间用英文逗号隔开,输入样例:"0,1,2"')
args = parser.parse_args()
return args
def main():
args = parse_args()
set_envs(args.DCU_ID)
# configs
config = configparser.ConfigParser()
config.read(args.config_path)
bind_port = int(config['default']['bind_port'])
model_path = config['llm']['local_llm_path']
use_vllm = config.getboolean('llm', 'use_vllm')
tensor_parallel_size = config.getint('llm', 'tensor_parallel_size')
stream_chat = config.getboolean('llm', 'stream_chat')
logger.info(f"Get params: model_path {model_path}, use_vllm {use_vllm}, tensor_parallel_size {tensor_parallel_size}, stream_chat {stream_chat}")
model, tokenizer, sampling_params = init_model(model_path, use_vllm, tensor_parallel_size)
if use_vllm:
if stream_chat:
vllm_inference_stream(bind_port, model, tokenizer, sampling_params)
else:
vllm_inference(bind_port, model, tokenizer, sampling_params)
else:
hf_inference(bind_port, model, tokenizer, stream_chat)
if __name__ == '__main__':
main()
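# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of this commit): a minimal client
# for the HTTP endpoints registered above. It assumes the service started by
# main() is reachable on localhost and that bind_port is 8003, the value used
# in the sample config.ini later in this commit; adjust host/port as needed.
import json
import requests


def ask(query, port=8003, endpoint='hf_inference'):
    # /hf_inference and /vllm_inference both accept {"query": ..., "history": [...]}
    # and reply with {"text": ...}, as implemented in the handlers above.
    resp = requests.post(f'http://127.0.0.1:{port}/{endpoint}',
                         json={'query': query, 'history': []})
    return resp.json()['text']


def ask_stream(query, port=8003):
    # /vllm_inference_stream writes "\0"-delimited JSON chunks; each chunk holds
    # a list with the cumulative generation so far under the "text" key.
    with requests.post(f'http://127.0.0.1:{port}/vllm_inference_stream',
                       json={'query': query, 'history': []}, stream=True) as resp:
        buffer = b''
        for data in resp.iter_content(chunk_size=None):
            buffer += data
            while b'\0' in buffer:
                chunk, buffer = buffer.split(b'\0', 1)
                if chunk:
                    yield json.loads(chunk.decode('utf-8'))['text']


# Example: print(ask('写一首诗'))
# ---------------------------------------------------------------------------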
import time
import os
import configparser
import argparse
# import torch
import asyncio
import uuid
from typing import AsyncGenerator
from loguru import logger
from aiohttp import web
# from multiprocessing import Value
from transformers import AutoModelForCausalLM, AutoTokenizer
from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams
COMMON = {
    "<光合组织登记网址>": "https://www.hieco.com.cn/partner?from=timeline",
    # ... (remaining entries identical to the COMMON dict above)
    "<餐费标准>": "一般地区的餐饮费用规定为不超过300元/人(一顿正餐),特殊地区则为不超过400元/人(一顿正餐),特殊地区的具体规定请参照公司的《差旅费管理制度》",
    "": "",
}
def init_model(model_path, use_vllm=False, tensor_parallel_size=1):
## init models
# huggingface
logger.info("Starting initial model of Llama - vllm")
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# vllm
sampling_params = SamplingParams(temperature=1,
top_p=0.95,
max_tokens=1024,
early_stopping=False,
stop_token_ids=[tokenizer.eos_token_id]
)
# vLLM基础配置
args = AsyncEngineArgs(model_path)
args.worker_use_ray = False
args.engine_use_ray = False
args.tokenizer = model_path
args.tensor_parallel_size = tensor_parallel_size
args.trust_remote_code = True
args.enforce_eager = True
args.max_model_len = 1024
args.dtype = 'float16'
# 加载模型
engine = AsyncLLMEngine.from_engine_args(args)
return engine, tokenizer, sampling_params
def llm_inference(args):
'''启动 Web 服务器,接收 HTTP 请求,并通过调用本地的 LLM 推理服务生成响应. '''
config = configparser.ConfigParser()
config.read(args.config_path)
bind_port = int(config['default']['bind_port'])
model_path = config['llm']['local_llm_path']
tensor_parallel_size = config.getint('llm', 'tensor_parallel_size')
use_vllm = config.getboolean('llm', 'use_vllm')
stream_chat = config.getboolean('llm', 'stream_chat')
logger.info(f"Get params: model_path {model_path}, use_vllm {use_vllm}, tensor_parallel_size {tensor_parallel_size}, stream_chat {stream_chat}")
model, tokenizer, sampling_params = init_model(model_path, tensor_parallel_size)
async def inference(request):
start = time.time()
input_json = await request.json()
prompt = input_json['query']
history = input_json['history']
messages = [{"role": "user", "content": prompt}]
logger.info("****************** use vllm ******************")
logger.info(f"before generate {messages}")
## 1
text = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True)
print(text)
assert model is not None
request_id = str(uuid.uuid4().hex)
results_generator = model.generate(text, sampling_params=sampling_params, request_id=request_id)
# Streaming case
async def stream_results() -> AsyncGenerator[bytes, None]:
final_output = None
async for request_output in results_generator:
final_output = request_output
text_outputs = [output.text for output in request_output.outputs]
ret = {"text": text_outputs}
print(ret)
# yield (json.dumps(ret) + "\0").encode("utf-8")
# yield web.json_response({'text': text_outputs})
assert final_output is not None
return [output.text for output in final_output.outputs]
if stream_chat:
logger.info("****************** in chat stream *****************")
# return StreamingResponse(stream_results())
text = await stream_results()
output_text = substitution(text)
logger.debug('问题:{} 回答:{} \ntimecost {} '.format(prompt, output_text, time.time() - start))
return web.json_response({'text': output_text})
# Non-streaming case
final_output = None
async for request_output in results_generator:
# if await request.is_disconnected():
# # Abort the request if the client disconnects.
# await engine.abort(request_id)
# return Response(status_code=499)
final_output = request_output
assert final_output is not None
text = [output.text for output in final_output.outputs]
end = time.time()
logger.debug('问题:{} 回答:{} \ntimecost {} '.format(prompt, text, end - start))
return web.json_response({'text': text})
app = web.Application()
app.add_routes([web.post('/inference', inference)])
web.run_app(app, host='0.0.0.0', port=bind_port)
def set_envs(dcu_ids):
try:
os.environ["CUDA_VISIBLE_DEVICES"] = dcu_ids
logger.info(f"Set environment variable CUDA_VISIBLE_DEVICES to {dcu_ids}")
except Exception as e:
logger.error(f"{e}, but got {dcu_ids}")
raise ValueError(f"{e}")
def parse_args():
'''参数'''
parser = argparse.ArgumentParser(
description='Feature store for processing directories.')
parser.add_argument(
'--config_path',
default='../config.ini',
help='config目录')
parser.add_argument(
'--query',
default='写一首诗',
help='提问的问题.')
parser.add_argument(
'--DCU_ID',
type=str,
default='4',
help='设置DCU卡号,卡号之间用英文逗号隔开,输入样例:"0,1,2"')
args = parser.parse_args()
return args
def main():
args = parse_args()
set_envs(args.DCU_ID)
llm_inference(args)
if __name__ == '__main__':
main()
import time
import os
import pickle
from loguru import logger
from .helper import ErrorCode
from .http_client import OpenAPIClient, ClassifyModel, CacheRetriever
from .inferencer import LLMInference
from .feature_database import DocumentProcessor, FeatureDataBase
SECURITY_TEMAPLTE = '判断以下句子是否涉及政治、辱骂、色情、恐暴、宗教、网络暴力、种族歧视等违禁内容,结果用 0~10 表示,不要解释直接给出得分。判断标准:涉其中任一问题直接得 10 分;完全不涉及得 0 分。直接给得分不要解释:“{}”' # noqa E501
GENERATE_TEMPLATE = '<Data>{}</Data> \n 回答要求:\n如果你不清楚答案,你需要澄清。\n避免提及你是从 <Data></Data> 获取的知识。\n保持答案与 <Data></Data> 中描述的一致。\n使用 Markdown 语法优化回答格式。\n使用与问题相同的语言回答。问题:"{}"'
MARKDOWN_TEMPLATE = '问题:“{}” \n请使用markdown格式回答此问题'
COMMON = {
"<光合组织登记网址>": "https://www.hieco.com.cn/partner?from=timeline",
"<官网>": "https://www.sugon.com/after_sale/policy?sh=1",
"<平台联系方式>": "1、访问官网,根据您所在地地址联系平台人员,网址地址:https://www.sugon.com/about/contact;\n2、点击人工客服进行咨询;\n3、请您拨打中科曙光服务热线400-810-0466联系人工进行咨询。",
"<购买与维修的咨询方法>": "1、确定付费处理,可以微信搜索'sugon中科曙光服务'小程序,选择'在线报修'业务\n2、先了解价格,可以微信搜索'sugon中科曙光服务'小程序,选择'其他咨询'业务\n3、请您拨打中科曙光服务热线400-810-0466",
"<服务器续保流程>": "1、微信搜索'sugon中科曙光服务'小程序,选择'延保与登记'业务\n2、点击人工客服进行登记\n3、请您拨打中科曙光服务热线400-810-0466根据语音提示选择维保与购买",
"<XC内外网OS网盘链接>": "【腾讯文档】XC内外网OS网盘链接:https://docs.qq.com/sheet/DTWtXbU1BZHJvWkJm",
"<W360-G30机器,安装Win7使用的镜像链接>": "W360-G30机器,安装Win7使用的镜像链接:https://pan.baidu.com/s/1SjHqCP6kJ9KzdJEBZDEynw;提取码:x6m4",
"<麒麟系统搜狗输入法下载链接>": "软件下载链接(百度云盘):链接:https://pan.baidu.com/s/18Iluvs4BOAfFET0yFMBeLQ,提取码:bhkf",
"<X660 G45 GPU服务器拆解视频网盘链接>": "链接: https://pan.baidu.com/s/1RkRGh4XY1T2oYftGnjLp4w;提取码: v2qi",
"<DS800,SANTRICITY存储IBM版本模拟器网盘链接>": "链接:https://pan.baidu.com/s/1euG9HGbPfrVbThEB8BX76g;提取码:o2ya",
"<E80-D312(X680-G55)风冷整机组装说明下载链接>": "链接:https://pan.baidu.com/s/17KDpm-Z9lp01WGp9sQaQ4w;提取码:0802",
"<X680 G55 风冷相关资料下载链接>": "链接:https://pan.baidu.com/s/1KQ-hxUIbTWNkc0xzrEQLjg;提取码:0802",
"<R620 G51刷写EEPROM下载>": "下载链接如下:http://10.2.68.104/tools/bytedance/eeprom/",
"<X7450A0服务器售后培训文件网盘链接>": "网盘下载:https://pan.baidu.com/s/1tZJIf_IeQLOWsvuOawhslQ?pwd=kgf1;提取码:kgf1",
"<福昕阅读器补丁链接>": "补丁链接: https://pan.baidu.com/s/1QJQ1kHRplhhFly-vxJquFQ,提取码: aupx1",
"<W330-H35A_22DB4/W3335HA安装win7网盘链接>": "硬盘链接: https://pan.baidu.com/s/1fDdGPH15mXiw0J-fMmLt6Q提取码: k97i",
"<X680 G55服务器售后培训资料网盘链接>": "云盘连接下载:链接:https://pan.baidu.com/s/1gaok13DvNddtkmk6Q-qLYg?pwd=xyhb提取码:xyhb",
"<展厅管理员>": "北京-穆淑娟18001053012\n天津-马书跃15720934870\n昆山-关天琪15304169908\n成都-贾小芳18613216313\n重庆-李子艺17347743273\n安阳-郭永军15824623085\n桐乡-李梦瑶18086537055\n青岛-陶祉伊15318733259",
"<线上预约展厅>": "北京、天津、昆山、成都、重庆、安阳、桐乡、青岛",
"<马华>": "联系人:马华,电话:13761751980,邮箱:china@pinbang.com",
"<梁静>": "联系人:梁静,电话:18917566297,邮箱:ing.liang@omaten.com",
"<徐斌>": "联系人:徐斌,电话:13671166044,邮箱:244898943@qq.com",
"<俞晓枫>": "联系人:俞晓枫,电话13750869272,邮箱:857233013@qq.com",
"<刘广鹏>": "联系人:刘广鹏,电话13321992411,邮箱:liuguangpeng@pinbang.com",
"<马英伟>": "联系人:马英伟,电话:13260021849,邮箱:13260021849@163.com",
"<杨洋>": "联系人:杨洋,电话15801203938,邮箱bing523888@163.com",
"<展会合规要求>": "1.展品内容:展品内容需符合公司合规要求,展示内容需经过法务合规审查。\n2.文字材料内容:文字材料内容需符合公司合规要求,展示内容需经过法务合规审查。\n3.展品标签:展品标签内容需符合公司合规要求。\n4.礼品内容:礼品内容需符合公司合规要求。\n5.视频内容:视频内容需符合公司合规要求,展示内容需经过法务合规审查。\n6.讲解词内容:讲解词内容需符合公司合规要求,展示内容需经过法务合规审查。\n7.现场发放材料:现场发放的材料内容需符合公司合规要求。\n8.展示内容:整体展示内容需要经过法务合规审查。",
"<展会质量>": "1.了解展会的组织者背景、往届展会的评价以及提供的服务支持,确保展会的专业性和高效性。\n.了解展会的规模、参观人数、行业影响力等因素,以判断展会是否能够提供足够的曝光度和商机。\n3.关注同行业其他竞争对手是否参展,以及他们的展位布置、展示内容等信息,以便制定自己的参展策略。\n4.展会的日期是否与公司的其他重要活动冲突,以及举办地点是否便于客户和合作伙伴的参观。\n5.销售部门会询问展会方提供的宣传渠道和推广服务,以及如何利用这些资源来提升公司及产品的知名度。\n6.记录展会期间的重要领导参观、商机线索、合作洽谈、公司拜访预约等信息,跟进后续商业机会。",
"<摊位费规则>": "根据展位面积大小,支付相应费用。\n展位照明费:支付展位内的照明服务费。\n展位保安费:支付展位内的保安服务费。\n展位网络使用费:支付展位内网络使用的费用。\n展位电源使用费:支付展位内电源使用的费用。",
"<展会主题要求>": "展会主题的确定需要符合公司产品和服务业务范围,以确保能够吸引目标客户群体。因此,确定展会主题时,需要考虑以下因素:\n专业性:展会的主题应确保专业性,符合行业特点和目标客户的需求。\n目标客户群体:展会的主题定位应考虑目标客户群体,确保能够吸引他们的兴趣。\n业务重点:展会的主题应突出公司的业务重点和优势,以便更好地推广公司的核心产品或服务。\n行业影响力:展会的主题定位需要考虑行业的最新发展趋势,以凸显公司的行业地位和影响力。\n往届展会经验:可以参考往届展会的主题定位,总结经验教训,以确定本届展会的主题。\n市场部意见:在确定展会主题时,应听取市场部的意见,确保主题符合公司的整体市场战略。\n领导意见:还需要考虑公司领导的意见,以确保展会主题符合公司的战略发展方向。",
"<办理展商证注意事项>": "人员范围:除公司领导和同事需要办理展商证外,展会运营工作人员也需要办理。\n提前准备:展商证的办理需要提前进行,以确保摄影师、摄像师等工作人员可以提前入场进行布置。\n办理流程:需要熟悉展商证的办理流程,准备好相关材料,如身份证件等。\n数量需求:需要评估所需的展商证数量,避免数量不足或过多的情况。\n有效期限:展商证的有效期限需要注意,避免在展期内过期。\n存放安全:办理完的展商证需要妥善保管,避免丢失或被他人使用。\n使用规范:使用展商证时需要遵守展会相关规定,不得转让给他人使用。\n回收处理:展会结束后,需要及时回收展商证,避免泄露相关信息。",
"<项目单价要求>": "请注意:无论是否年框供应商,项目单价都不得超过采购部制定的“2024常见活动项目标准单价”,此报价仅可内部使用,严禁外传",
"<年框供应商细节表格>": "在线表格https://kdocs.cn/l/camwZE63frNw",
"<年框供应商流程>": "1.需求方发出项目需求(大型项目需比稿)\n2.外协根据项目需求报价,提供需求方“预算单”(按照基准单价报价,如有发现不按单价情况,解除合同不再使用)\n3.需求方确认预算价格,并提交OA市场活动申请\n4.外协现场执行\n5.需求方现场验收,并签署验收单(物料、设备、人员等实际清单)\n6.外协出具结算单(金额与验收单一致,加盖公章)、结案报告、年框合同,作为报销凭证\n7.外协请需求方项目负责人填写“满意度调研表”(如无,会影响年度评价)\n8.需求方项目经理提交报销",
"<市场活动结案报告内容>": "1.项目简介(时间、地点、参与人数等);2.最终会议安排;3.活动各环节现场图片;4.费用相关证明材料(如执行人员、物料照片);5.活动成效汇总;6.活动原始照片/视频网络链接",
"<展板设计选择>": "1.去OA文档中心查找一些设计模板; 2. 联系专业的活动服务公司来协助设计",
"<餐费标准>": "一般地区的餐饮费用规定为不超过300元/人(一顿正餐),特殊地区则为不超过400元/人(一顿正餐),特殊地区的具体规定请参照公司的《差旅费管理制度》",
"":"",
}
def substitution(chunks):
    # Expand placeholder tags (e.g. "<官网>") in each retrieved chunk using COMMON.
    import re
    new_chunks = []
    for chunk in chunks:
        matchObj = re.split('.*(<.*>).*', chunk, flags=re.M | re.I)
        if len(matchObj) > 1:
            obj = matchObj[1]
            replace_str = COMMON.get(obj)
            if replace_str:
                chunk = chunk.replace(obj, replace_str)
                logger.info(f"{obj} replaced with {replace_str}, after: {chunk}")
        new_chunks.append(chunk)
    return new_chunks
class Worker:
    def __init__(self, config):
        self.work_dir = config['default']['work_dir']
        llm_model = config['model']['llm_model']
        local_model = config['model']['local_model']
        llm_service_address = config['model']['llm_service_address']
        cls_model_path = config['model']['cls_model_path']
        local_server_address = config['model']['local_service_address']
        reject_throttle = float(config['feature_database']['reject_throttle'])
        if not llm_service_address:
            raise Exception('llm_service_address is required in config.ini')
        if not cls_model_path:
            raise Exception('cls_model_path is required in config.ini')
        self.max_input_len = int(config['model']['max_input_length'])
        self.retriever = CacheRetriever(
            self.embedding_model_path,
            self.reranker_model_path).get(reject_throttle=reject_throttle,
                                          work_dir=self.work_dir)
        self.openapi_service = OpenAPIClient(llm_service_address, llm_model)
        self.openapi_local_server = OpenAPIClient(local_server_address, local_model)
        self.classify_service = ClassifyModel(cls_model_path)
        self.tasks = {}
        if os.path.exists(self.work_dir + '/tasks_status.pkl'):
            with open(self.work_dir + '/tasks_status.pkl', 'rb') as f:
                self.tasks = pickle.load(f)

    def generate_prompt(self,
                        history_pair,
                        instruction: str,
                        context: str = ''):
        if context is not None and len(context) > 0:
            str_context = str(context)
            if len(str_context) > self.max_input_len:
                str_context = str_context[:self.max_input_len]
            instruction = GENERATE_TEMPLATE.format(str_context, instruction)
        real_history = []
        for pair in history_pair:
            ...  # (loop body elided)
        return instruction, real_history

    async def generater(self, content):
        for word in content:
            yield word
            # await asyncio.sleep(0.1)

    async def response_by_common(self, query, history, output_format=False, stream=False):
        if output_format:
            query = MARKDOWN_TEMPLATE.format(query)
        logger.info('Prompt is: {}, History is: {}'.format(query, history))
        response_direct = await self.openapi_service.chat(query, history, stream=stream)
        return response_direct

    def format_rag_result(self, chunks, references, stream=False):
        result = "针对您的问题,我们找到了如下解决方案:\n%s"
        content = ""
        for i, item in enumerate(references):
            if item.endswith(".json"):
                content += " - %s.%s\n" % (i + 1, chunks[i])
            else:
                line = chunks[i]
                if len(line) > 300:
                    line = line[:300] + "..." + '\n'
                line += "详细内容参见:%s" % item
                content += " - %s.%s\n" % (i + 1, line)
        if stream:
            return self.generater((result % content))
        return result % content

    def response_by_finetune(self, query, history=[]):
        '''微调模型回答'''
        logger.info('Prompt is: {}, History is: {}'.format(query, history))
        response_direct = self.openapi_local_server.chat(query, history)
        return response_direct

    async def produce_response(self, config, query, history, stream=False):
        response = ''
        references = []
        use_template = config.getboolean('default', 'use_template')
        output_format = config.getboolean('default', 'output_format')
        if query is None:
            return ErrorCode.NOT_A_QUESTION, response, references
        logger.info('input: %s' % [query, history])
        # classify
        score = self.classify_service.classfication(query)
        if score > 0.8:
            logger.debug('Start RAG search')
            chunks, references = self.retriever.query(query)
            if len(chunks) == 0:
                logger.debug('Response by finetune model')
                chunks = [self.response_by_finetune(query, history=history)]
            elif use_template:
                logger.debug('Response by template')
                response = self.format_rag_result(chunks, references, stream=stream)
                return ErrorCode.SUCCESS, response, references
            logger.debug('Response with common model')
            new_chunks = substitution(chunks)
            prompt, history = self.generate_prompt(
                instruction=query,
                context=new_chunks,
                history_pair=history)
            logger.debug('prompt: {}'.format(prompt))
            response = await self.response_by_common(prompt, history=history, output_format=False, stream=stream)
            return ErrorCode.SUCCESS, response, references
        else:
            logger.debug('Response by common model')
            response = await self.response_by_common(query, history=history, output_format=output_format, stream=stream)
            return ErrorCode.SUCCESS, response, references
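# Usage sketch (illustrative only, not part of this commit): driving the Worker
# above from an asyncio entry point. It assumes config.ini provides the sections
# read in __init__ and produce_response above; since produce_response is a
# coroutine, it has to be awaited (or run via asyncio.run):
#
#     import asyncio
#     import configparser
#     cfg = configparser.ConfigParser()
#     cfg.read('config.ini')
#     worker = Worker(cfg)
#     code, reply, refs = asyncio.run(
#         worker.produce_response(cfg, '介绍下K100_AI?', history=[]))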
# ...
from llm_service import Worker, llm_inference


def check_envs(args):
    if all(isinstance(item, int) for item in args.DCU_ID):
        os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, args.DCU_ID))
        logger.info(f"Set environment variable CUDA_VISIBLE_DEVICES to {args.DCU_ID}")
    else:
        logger.error(f"The --DCU_ID argument must be a list of integers, but got {args.DCU_ID}")
        raise ValueError("The --DCU_ID argument must be a list of integers")


def parse_args():
    """Parse args."""
    parser = argparse.ArgumentParser(description='Executor.')
    parser.add_argument(
        '--DCU_ID',
        default=[1, 2, 6, 7],
        help='设置DCU')
    parser.add_argument(
        '--config_path',
        default='/path/to/your/ai/config.ini',
        type=str,
        help='config.ini路径')
    parser.add_argument(
        '--standalone',
        default=False,
        help='部署LLM推理服务.')
    parser.add_argument(
        '--accelerate',
        default=False,
        type=bool,
        help='LLM推理是否启用加速'
    )
    args = parser.parse_args()
    return args


# ...
def reply_workflow(assistant):
    queries = ['你好,我们公司想要购买几台测试机,请问需要联系贵公司哪位?']
    for query in queries:
        code, reply, references = assistant.produce_response(query=query,
                                                             history=[],
                                                             # ... (remaining arguments elided)


# ...
def run():
    args = parse_args()
    if args.standalone is True:
        import time
        check_envs(args)
        server_ready = Value('i', 0)
        server_process = Process(target=llm_inference,
                                 args=(args.config_path,
                                       len(args.DCU_ID),
                                       args.accelerate,
                                       server_ready))
        server_process.daemon = True
        server_process.start()
        while True:
            if server_ready.value == 0:
                logger.info('waiting for server to be ready..')
                time.sleep(15)
            elif server_ready.value == 1:
                break
    # ...
[default]
work_dir = /path/to/your/ai/work_dir
bind_port = 8003
[rag]
embedding_model_path = /path/to/your/text2vec-large-chinese
reranker_model_path = /path/to/your/bce-reranker-base_v1
vector_top_k = 5
es_top_k = 5
es_url = http://10.2.106.50:31920
index_name = dcu_knowledge_base
from __future__ import annotations
import argparse
import uuid
from abc import ABC
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Tuple
from loguru import logger
import jieba.analyse
from langchain.chains.llm import LLMChain
from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.llms.base import BaseLLM
from langchain.prompts.prompt import PromptTemplate
from langchain.utils import get_from_dict_or_env
from langchain.vectorstores.base import VectorStore
if TYPE_CHECKING:
from elasticsearch import Elasticsearch # noqa: F401
def _default_text_mapping() -> Dict:
return {'properties': {'text': {'type': 'text'}}}
DEFAULT_PROMPT = PromptTemplate(
input_variables=['question'],
template="""分析给定Question,提取Question中包含的KeyWords,输出列表形式
Examples:
Question: 达梦公司在过去三年中的流动比率如下:2021年:3.74倍;2020年:2.82倍;2019年:2.05倍。
KeyWords: ['过去三年', '流动比率', '2021', '3.74', '2020', '2.82', '2019', '2.05']
----------------
Question: {question}
KeyWords: """,
)
class ElasticKeywordsSearch(VectorStore, ABC):
def __init__(
self,
elasticsearch_url: str,
index_name: str,
drop_old: Optional[bool] = False,
*,
ssl_verify: Optional[Dict[str, Any]] = None,
llm_chain: Optional[LLMChain] = None,
):
try:
import elasticsearch
except ImportError:
logger.error('Could not import elasticsearch python package. '
'Please install it with `pip install elasticsearch`.')
return
self.index_name = index_name
self.llm_chain = llm_chain
self.drop_old = drop_old
_ssl_verify = ssl_verify or {}
self.elasticsearch_url = elasticsearch_url
self.ssl_verify = _ssl_verify
try:
self.client = elasticsearch.Elasticsearch(elasticsearch_url, **_ssl_verify)
except ValueError as e:
logger.error(f'Your elasticsearch client string is mis-formatted. Got error: {e}')
return
if drop_old:
try:
self.client.indices.delete(index=index_name)
except elasticsearch.exceptions.NotFoundError:
logger.info(f"Index '{index_name}' not found, nothing to delete.")
except Exception as e:
logger.error(f"Error occurred while trying to delete index '{index_name}': {e}")
logger.info(f"ElasticKeywordsSearch initialized with URL: {elasticsearch_url} and index: {index_name}")
def add_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
refresh_indices: bool = True,
**kwargs: Any,
) -> List[str]:
try:
from elasticsearch.exceptions import NotFoundError
from elasticsearch.helpers import bulk
except ImportError:
raise ImportError('Could not import elasticsearch python package. '
'Please install it with `pip install elasticsearch`.')
requests = []
ids = ids or [str(uuid.uuid4()) for _ in texts]
mapping = _default_text_mapping()
# check to see if the index already exists
try:
self.client.indices.get(index=self.index_name)
if texts and self.drop_old:
self.client.indices.delete(index=self.index_name)
self.create_index(self.client, self.index_name, mapping)
except NotFoundError:
# TODO would be nice to create index before embedding,
# just to save expensive steps for last
self.create_index(self.client, self.index_name, mapping)
for i, text in enumerate(texts):
metadata = metadatas[i] if metadatas else {}
request = {
'_op_type': 'index',
'_index': self.index_name,
'text': text,
'metadata': metadata,
'_id': ids[i],
}
requests.append(request)
bulk(self.client, requests)
if refresh_indices:
self.client.indices.refresh(index=self.index_name)
return ids
def similarity_search(self,
query: str,
k: int = 4,
query_strategy: str = 'match_phrase',
must_or_should: str = 'should',
**kwargs: Any) -> List[Document]:
if k == 0:
# pm need to control
return []
docs_and_scores = self.similarity_search_with_score(query,
k=k,
query_strategy=query_strategy,
must_or_should=must_or_should,
**kwargs)
documents = [d[0] for d in docs_and_scores]
return documents
@staticmethod
def _relevance_score_fn(distance: float) -> float:
"""Normalize the distance to a score on a scale [0, 1]."""
# Todo: normalize the es score on a scale [0, 1]
return distance
def _select_relevance_score_fn(self) -> Callable[[float], float]:
return self._relevance_score_fn
def similarity_search_with_score(self,
query: str,
k: int = 4,
query_strategy: str = 'match_phrase',
must_or_should: str = 'should',
**kwargs: Any) -> List[Tuple[Document, float]]:
if k == 0:
# pm need to control
return []
assert must_or_should in ['must', 'should'], 'only support must and should.'
# llm or jiaba extract keywords
if self.llm_chain:
keywords_str = self.llm_chain.run(query)
print('llm search keywords:', keywords_str)
try:
keywords = eval(keywords_str)
if not isinstance(keywords, list):
raise ValueError('Keywords extracted by llm is not list.')
except Exception as e:
print(str(e))
keywords = jieba.analyse.extract_tags(query, topK=10, withWeight=False)
else:
keywords = jieba.analyse.extract_tags(query, topK=10, withWeight=False)
logger.info('jieba search keywords:{}'.format(keywords))
match_query = {'bool': {must_or_should: []}}
for key in keywords:
match_query['bool'][must_or_should].append({query_strategy: {'text': key}})
response = self.client_search(self.client, self.index_name, match_query, size=k)
hits = [hit for hit in response['hits']['hits']]
        # Return (Document, score) pairs so the List[Tuple[Document, float]] annotation,
        # similarity_search() above and the __main__ demo below all get what they expect.
        docs_and_scores = [
            (Document(
                page_content=hit['_source']['text'],
                metadata={**hit['_source']['metadata'], 'relevance_score': hit['_score']}
            ), hit['_score'])
            for hit in hits]
return docs_and_scores
@classmethod
def from_texts(
cls,
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
index_name: Optional[str] = None,
refresh_indices: bool = True,
llm: Optional[BaseLLM] = None,
prompt: Optional[PromptTemplate] = DEFAULT_PROMPT,
drop_old: Optional[bool] = False,
**kwargs: Any,
) -> ElasticKeywordsSearch:
elasticsearch_url = get_from_dict_or_env(kwargs, 'elasticsearch_url', 'ELASTICSEARCH_URL')
if 'elasticsearch_url' in kwargs:
del kwargs['elasticsearch_url']
index_name = index_name or uuid.uuid4().hex
if llm:
llm_chain = LLMChain(llm=llm, prompt=prompt)
vectorsearch = cls(elasticsearch_url,
index_name,
llm_chain=llm_chain,
drop_old=drop_old,
**kwargs)
else:
vectorsearch = cls(elasticsearch_url, index_name, drop_old=drop_old, **kwargs)
vectorsearch.add_texts(texts,
metadatas=metadatas,
ids=ids,
refresh_indices=refresh_indices)
return vectorsearch
def create_index(self, client: Any, index_name: str, mapping: Dict) -> None:
version_num = client.info()['version']['number'][0]
version_num = int(version_num)
if version_num >= 8:
client.indices.create(index=index_name, mappings=mapping)
else:
client.indices.create(index=index_name, body={'mappings': mapping})
def client_search(self, client: Any, index_name: str, script_query: Dict, size: int) -> Any:
version_num = client.info()['version']['number'][0]
version_num = int(version_num)
if version_num >= 8:
response = client.search(index=index_name, query=script_query, size=size, timeout='5s')
else:
response = client.search(index=index_name, body={'query': script_query, 'size': size}, timeout='5s')
return response
def delete(self, **kwargs: Any) -> None:
# TODO: Check if this can be done in bulk
self.client.indices.delete(index=self.index_name)
def read_text(filepath):
with open(filepath) as f:
txt = f.read()
return txt
def parse_args():
"""Parse command-line arguments."""
parser = argparse.ArgumentParser(
description='Feature store for processing directories.')
parser.add_argument(
'--elasticsearch_url',
type=str,
default='http://127.0.0.1:9200')
parser.add_argument(
'--index_name',
type=str,
default='dcu_knowledge_base')
parser.add_argument(
'--query',
type=str,
default='介绍下K100_AI?')
args = parser.parse_args()
return args
if __name__ == '__main__':
args = parse_args()
elastic_search = ElasticKeywordsSearch(
elasticsearch_url=args.elasticsearch_url,
index_name=args.index_name,
drop_old=False
)
# texts = []
# file_list = ['/home/zhangwq/data/doc_new/preprocess/dbaa1604.text',
# '/home/zhangwq/data/doc_new/preprocess/a8e2e50d.text',
# '/home/zhangwq/data/doc_new/preprocess/a3fdf916.text',
# '/home/zhangwq/data/doc_new/preprocess/9d2683f3.text',
# '/home/zhangwq/data/doc_new/preprocess/2584c250.text']
# for file in file_list:
# text = read_text(file)
# texts.append(text)
# metadatas = [
# {"source": "白皮书-K100.pdf", "type": "text"},
# {"source": "DCU人工智能基础软件系统DAS1.0介绍.pdf", "type": "text"},
# {"source": "202404-DCU优势测试项.pdf", "type": "text"},
# {"source": "202301-达芬奇架构简介.pdf", "type": "text"},
# {"source": "曙光DCU在大模型方面的布局与应用.docx", "type": "text"},
# ]
# ids = ["doc1", "doc2", "doc3", "doc4", "doc5"]
# elastic_search.add_texts(texts, metadatas=metadatas, ids=ids)
search_results = elastic_search.similarity_search_with_score(args.query, k=5)
for result in search_results:
logger.debug('Query: {} \nDoc: {} \nScore: {}'.format(args.query, result[0], result[1]))
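# Usage sketch (illustrative only, not part of this commit): wiring the [rag]
# section of the config.ini shown earlier into ElasticKeywordsSearch above.
# Mapping es_url / index_name / es_top_k onto these arguments is an assumption
# based on the option names; the commit does not show this glue code itself.
import configparser

cfg = configparser.ConfigParser()
cfg.read('/path/to/your/ai/config.ini')
es_search = ElasticKeywordsSearch(
    elasticsearch_url=cfg['rag']['es_url'],
    index_name=cfg['rag']['index_name'],
    drop_old=False)
for doc, score in es_search.similarity_search_with_score(
        '介绍下K100_AI?', k=cfg.getint('rag', 'es_top_k')):
    print(round(score, 3), doc.page_content[:80])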
import matplotlib.pyplot as plt
import numpy as np
import os
from matplotlib import font_manager
# 设置中文字体
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用黑体
plt.rcParams['axes.unicode_minus'] = False # 解决负号显示问题
# 数据
pollution_levels = ['原始', '0.5倍', '1倍', '1.5倍', '2倍']
vector_search = {
'hit_rate': [91.24, 95.99, 92.80, 86.45, 72.10],
'mrr': [97.66, 96.26, 95.67, 94.98, 96.23]
}
keyword_search = {
'hit_rate': [98.77, 98.28, 96.86, 93.93, 84.14],
'mrr': [96.97, 96.02, 94.03, 92.18, 86.23]
}
hybrid_search = {
'hit_rate': [99.80, 99.62, 99.37, 94.49, 94.10],
'mrr': [94.14, 94.55, 92.09, 89.42, 82.12]
}
hybrid_search_after_decoupling = {
'hit_rate': [99.71, 99.48, 99.04, 97.89, 92.19],
'mrr': [95.88, 94.12, 91.65, 87.90, 79.86]
}
# 创建图表
fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=(24, 6))
# fig.suptitle('搜索方法比较实验', fontsize=20)
# 函数来绘制单个子图
def plot_subplot(ax, data, title):
ax.plot(pollution_levels, data['hit_rate'], 'ro-', label='命中率')
ax.plot(pollution_levels, data['mrr'], 'bo-', label='平均倒数排名')
ax.set_title(title, fontsize=20)
ax.set_ylim(70, 100)
ax.set_xlabel('干扰程度', fontsize=18)
ax.set_ylabel('百分比', fontsize=18)
ax.legend(fontsize=16)
ax.grid(True, linestyle='--', alpha=0.7)
ax.tick_params(axis='both', which='major', labelsize=16)
# 设置x轴标签旋转
ax.set_xticklabels(pollution_levels, ha='right')
# 绘制子图
plot_subplot(ax1, vector_search, '向量搜索')
plot_subplot(ax2, keyword_search, '关键字搜索')
plot_subplot(ax3, hybrid_search, '混合搜索')
plot_subplot(ax4, hybrid_search_after_decoupling, '混合搜索_架构拆解')
# 调整布局
plt.tight_layout()
# 保存图表
file_path = os.path.join(os.getcwd(), 'search_methods_comparison.png')
plt.savefig(file_path, dpi=300, bbox_inches='tight')
print(f"图表已保存为: {file_path}")
# 检查文件是否成功创建
if os.path.exists(file_path):
print(f"文件成功创建在: {file_path}")
print(f"文件大小: {os.path.getsize(file_path)} 字节")
else:
print(f"文件创建失败,请检查权限或其他问题")