Commit 6088e14e authored by chenych

jifu v1.0

parent 2397728d
## 树博士 (Dr. Shu): An Intelligent Customer-Service Assistant Built on Retrieval-Augmented Generation

Technical-support Q&A service.

树博士 (Dr. Shu) is a domain knowledge assistant that combines RAG with an LLM. Key features:

1. Handles complex vertical-domain scenarios and answers user questions without producing "hallucinations".
2. Provides an algorithmic pipeline for answering technical questions.
3. Modular components keep deployment cost low; the system is secure, reliable, and robust.

## Step 1. Environment setup

- Pull the image and create a container: [image download (光源/SourceFind)](https://sourcefind.cn/#/image/dcu/pytorch?activeName=overview)

```shell
docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.1.0-ubuntu20.04-dtk24.04.1-py3.10
```

- Run the container. Adjust the `-v` arguments to match where the models are stored on the host:

```shell
docker run -dit --name assitant --privileged --device=/dev/kfd --device=/dev/dri/ --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --ulimit memlock=-1:-1 --ipc=host --network host --group-add video -v /opt/hyhal:/opt/hyhal -v /path/to/model:/opt/model:ro image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.1.0-centos7.6-dtk24.04-py310
```

- Install the project and its system dependencies:

```shell
git clone http://10.6.10.68/aps/ai.git
cd ai

# on CentOS use yum instead of apt
apt update
apt install python-dev libxml2-dev libxslt1-dev antiword unrtf poppler-utils pstotext tesseract-ocr flac ffmpeg lame libmad0 libsox-fmt-mp3 sox libjpeg-dev swig libpulse-dev
```

- Download and install the additional DTK package: [faiss](http://10.6.10.68:8000/release/faiss/dtk24.04/faiss-1.7.2_dtk24.04_gitb7348e7df780-py3-none-any.whl)

```shell
wget http://10.6.10.68:8000/release/faiss/dtk24.04/faiss-1.7.2_dtk24.04_gitb7348e7df780-py3-none-any.whl
pip install faiss-1.7.2_dtk24.04_gitb7348e7df780-py3-none-any.whl
```
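A quick import check (a minimal sketch; nothing project-specific is assumed beyond the wheel installed above) confirms the DCU faiss build is usable before continuing:

```python
# Minimal check that the faiss wheel installed above is importable and functional
import faiss

print(faiss.__version__)         # this wheel reports 1.7.2
index = faiss.IndexFlatL2(16)    # building a tiny index exercises the native library
print(index.is_trained, index.ntotal)
```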
- Install Firefox and geckodriver (used for web retrieval) and configure the environment variables:

```shell
sudo apt-get update
sudo apt-get install firefox                    # install the Firefox browser
wget https://github.com/mozilla/geckodriver/releases/latest/download/geckodriver-v0.34.0-linux64.tar.gz   # download geckodriver
tar -xvzf geckodriver-v0.34.0-linux64.tar.gz    # extract geckodriver
sudo mv geckodriver /usr/local/bin/             # move geckodriver into /usr/local/bin
nano ~/.bashrc                                  # edit ~/.bashrc and add the two lines below
export PATH=$PATH:/usr/local/bin
export GECKODRIVER=/usr/local/bin/geckodriver
source ~/.bashrc                                # save, exit, and apply the changes
echo $GECKODRIVER                               # verify the environment variable
chmod +x /usr/local/bin/geckodriver             # troubleshooting: if geckodriver lacks execute permission, grant it
```

- Install the Python dependencies:

```shell
pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
```
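If the browser setup is in doubt, the following headless-Firefox smoke test mirrors how the project's web-retrieval module drives Selenium (a sketch only; the target URL is an arbitrary example):

```python
# Headless Firefox smoke test; assumes firefox and geckodriver are installed as above
from selenium import webdriver
from selenium.webdriver.firefox.options import Options

options = Options()
options.add_argument('--headless')            # no browser window
driver = webdriver.Firefox(options=options)
try:
    driver.get('https://www.bing.com')
    print(driver.title)                       # a non-empty title means the stack works
finally:
    driver.quit()
```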
## Step 2. Prepare the models

Before the first run of 树博士, download the required models locally:

- [BCERerank download](https://modelscope.cn/models/maidalun/bce-reranker-base_v1/files)
- [Text2vec-large-chinese download](https://modelscope.cn/models/Jerry0/text2vec-large-chinese/files)
- [Llama3-8B-Chinese-Chat download](https://hf-mirror.com/shenzhi-wang/Llama3-8B-Chinese-Chat)

The bert-finetune-dcu model is a fine-tuned model; contact the developers to obtain it.
Note that the download location must match the `-v` argument used in Step 1.
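One way to fetch the two ModelScope models programmatically is sketched below; it assumes the `modelscope` Python package is available, and the cache directory is an example path that should sit under the directory mounted in Step 1:

```python
# Example: download the reranker and embedding models from ModelScope (assumes `modelscope` is installed)
from modelscope import snapshot_download

snapshot_download('maidalun/bce-reranker-base_v1', cache_dir='/path/to/model')
snapshot_download('Jerry0/text2vec-large-chinese', cache_dir='/path/to/model')
```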
## Step 3. Edit the configuration files

ai/config.ini

```shell
[default]
work_dir = /path/to/your/ai/work_dir      # absolute path of ai/work_dir
bind_port = 8000                          # port the service exposes
summarize_query = False                   # whether to summarize the question before querying
use_rag = True                            # whether to answer with RAG retrieval
use_template = False                      # whether to render RAG results with a template
output_format = True                      # whether to format the output as Markdown
filter_illegal = True                     # whether to screen for sensitive questions
[feature_database]
file_processor = 10                       # number of file-parsing threads
[llm]
llm_service_address = http://127.0.0.1:8001           # address of the general LLM service
llm_model = /path/to/your/Llama3-8B-Chinese-Chat/      # general model name
cls_service_address = http://127.0.0.1:8002            # address of the classification service
rag_service_address = http://127.0.0.1:8003            # address of the RAG service
max_input_length = 1400                   # input length limit
```
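For reference, the services load these files with Python's standard configparser; the sketch below shows how the keys above are consumed (the surrounding service code is simplified):

```python
# Sketch: how config.ini keys are read with configparser
import configparser

config = configparser.ConfigParser()
config.read('/root/ai/config.ini')        # the path passed via --config_path

work_dir = config['default']['work_dir']
bind_port = config.getint('default', 'bind_port')
use_rag = config.getboolean('default', 'use_rag')
llm_service_address = config['llm']['llm_service_address']
print(work_dir, bind_port, use_rag, llm_service_address)
```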
ai/rag/config.ini

```shell
[default]
work_dir = /path/to/your/ai/work_dir      # absolute path of ai/work_dir
bind_port = 8003                          # port the service exposes
[rag]
embedding_model_path = /path/to/your/text2vec-large-chinese   # absolute path of the text2vec-large-chinese model directory
reranker_model_path = /path/to/your/bce-reranker-base_v1      # absolute path of the bce-reranker-base_v1 model directory
vector_top_k = 5                          # number of results from the vector store
es_top_k = 5                              # number of results from Elasticsearch
es_url = http://10.2.106.50:31920         # Elasticsearch address
index_name = dcu_knowledge_base           # Elasticsearch index name
```
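Before starting the RAG service it is worth confirming that the Elasticsearch instance behind `es_url` is reachable; a minimal sketch (adjust the URL to your deployment):

```python
# Quick reachability check for the Elasticsearch endpoint configured above
import requests

es_url = 'http://10.2.106.50:31920'        # value of es_url in ai/rag/config.ini
resp = requests.get(es_url, timeout=10)
print(resp.status_code)                    # 200 plus cluster-info JSON means ES is reachable
print(resp.json().get('version', {}))
```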
The vector store must be obtained from the developers and placed under work_dir:

```shell
ll ai/work_dir/db_response/
total 173696
-rw-r--r-- 1 root root 154079277 Aug 22 10:04 index.faiss
-rw-r--r-- 1 root root  23767932 Aug 22 10:04 index.pkl
```
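To sanity-check the delivered vector store, it can be opened the same way the retriever does; a sketch assuming langchain-community is installed and the text2vec model from Step 2 is in place:

```python
# Sketch: load the delivered FAISS store the way the RAG retriever does
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS

embeddings = HuggingFaceEmbeddings(model_name='/path/to/your/text2vec-large-chinese',
                                   model_kwargs={'device': 'cuda'})
vs = FAISS.load_local('ai/work_dir/db_response',
                      embeddings=embeddings,
                      allow_dangerous_deserialization=True)
print(vs.index.ntotal)                     # number of vectors in the store
```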
## Step 4. Run

- Start the LLM service. Pin it to a DCU card (preferably an idle one) and serve the model through vLLM's OpenAI-compatible API:

```shell
export CUDA_VISIBLE_DEVICES='1'
python -m vllm.entrypoints.openai.api_server --model /opt/model/Llama3-8B-Chinese-Chat/ --enforce-eager --dtype float16 --trust-remote-code --port 8001 > llm.log 2>&1 &
```
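Once the service is up it can be exercised directly through vLLM's OpenAI-compatible endpoint; a sketch (the `model` field must match the `--model` path used above):

```python
# Sketch: call the OpenAI-compatible chat endpoint served by vLLM above
import requests

payload = {
    'model': '/opt/model/Llama3-8B-Chinese-Chat/',
    'messages': [{'role': 'user', 'content': '你好'}],
}
resp = requests.post('http://127.0.0.1:8001/v1/chat/completions', json=payload, timeout=300)
print(resp.json()['choices'][0]['message']['content'])
```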
- Start the classification model:

```shell
python ai/classify/classify.py --model_path /opt/model/bert-finetune-dcu/ --port 8002 --dcu_id 2 > classify.log 2>&1 &
```

Arguments:
- `--model_path`: path to the classification model
- `--port`: classification service port
- `--dcu_id`: DCU card to run on

- Start the RAG service:

```shell
python ai/rag/retriever.py --config_path /root/ai/rag/config.ini --dcu_id 3 > rag.log 2>&1 &
```

Note: the three services above all require DCU devices, so they can run in the background inside a single container.

- Start the assistant service:

```shell
python ai/server_start.py --config_path /root/ai/config.ini --log_path /var/log/assistant.log
```

Note: this service does not depend on a DCU device and can run in Kubernetes.

- Client call example:

```shell
python client.py --action query --query "你好, 我们公司想要购买几台测试机, 请问需要联系贵公司哪位?"
..
20xx-xx-xx hh:mm:ss.sss | DEBUG | __main__:<module>:30 -
reply: 您好,您需要联系中科曙光售前咨询平台,您可以通过访问官网,根据您所在地地址联系平台人员,
或者点击人工客服进行咨询,或者拨打中科曙光服务热线400-810-0466联系人工进行咨询。
如果您需要购买机架式服务器,I620--G30,请与平台人员联系。,
ref: ['train.json', 'FAQ_clean.xls', 'train.json'
```
# 🛠️ Roadmap

1. Integrate web search into the workflow
2. More advanced LLM fine-tuning and support (llama3)
3. vLLM acceleration
4. Support multiple local LLMs and remote LLM calls
5. Support multimodal (text + image) Q&A

# 🛠️ Troubleshooting

`Could not load library with AVX2 support due to: ModuleNotFoundError("No module named 'faiss.swigfaiss_avx2'")`

Fix: locate the installed faiss package

```
import faiss
print(faiss.__file__)
# /.../python3.10/site-packages/faiss/__init__.py
```

then add a symlink:

```
# cd your_python_path/site-packages/faiss
cd /.../python3.10/site-packages/faiss/
ln -s swigfaiss.py swigfaiss_avx2.py
```
from .llm_service import ChatAgent  # noqa E401
from .llm_service import ErrorCode  # noqa E401
from .llm_service import FeatureDataBase  # noqa E401
from .llm_service import LLMInference  # noqa E401
from .llm_service import llm_inference  # noqa E401
from .llm_service import Worker  # noqa E401
from .llm_service import DocumentName  # noqa E401
from .llm_service import DocumentProcessor  # noqa E401
from .llm_service import rag_retrieve  # noqa E401
import json
import requests
import argparse
import re

'''
Usage examples:
    Query the public knowledge base:   python client.py --action query --query '<question>'
    Query a private knowledge base:    python client.py --action query --query '<question>' --user_id '<user_id>'
'''

base_url = 'http://127.0.0.1:8000/%s'


def query(query, user_id=None):
    url = base_url % 'work'
    try:
        header = {'Content-Type': 'application/json'}
        # Add history to data
        data = {
            'query': query,
            'history': []
        }
        if user_id:
            data['user_id'] = user_id
        resp = requests.post(url,
                             headers=header,
                             data=json.dumps(data),
                             timeout=300)
        if resp.status_code != 200:
            raise Exception(str((resp.status_code, resp.reason)))
        return resp.json()['reply'], resp.json()['references']
    except Exception as e:
        print(str(e))
        # keep the return shape consistent with the success path
        return '', []


def get_streaming_response(response: requests.Response):
    for chunk in response.iter_lines(chunk_size=1024, decode_unicode=False,
                                     delimiter=b"\0"):
        if chunk:
            pattern = re.compile(rb'data: "(\\u[0-9a-fA-F]{4})"')
            matches = pattern.findall(chunk)
            decoded_data = []
            for match in matches:
                hex_value = match[2:].decode('ascii')
                char = chr(int(hex_value, 16))
                decoded_data.append(char)
                print(char, end="", flush=True)


def stream_query(query):
    url = base_url % 'stream'
    try:
        headers = {
            "Content-Type": "text/event-stream",
            "Cache-Control": "no-cache",
            "Connection": "keep-alive"
        }
        # Add history to data
        data = {
            'query': query,
            'history': []
        }
        resp = requests.get(url,
                            headers=headers,
                            data=json.dumps(data),
                            timeout=300,
                            verify=False,
                            stream=True)
        get_streaming_response(resp)
    except Exception as e:
        print(str(e))


def parse_args():
    parser = argparse.ArgumentParser(description='.')
    parser.add_argument('--action', default='query', help='')
    parser.add_argument('--query',
                        default='your query',
                        help='')
    parser.add_argument('--user_id', default='')
    parser.add_argument('--stream', action='store_true')
    return parser.parse_args()


if __name__ == '__main__':
    args = parse_args()
    if args.stream:
        stream_query(args.query)
    else:
        reply, ref = query(args.query, args.user_id)
        print('reply: {} \nref: {} '.format(reply, ref))
[default]
work_dir = /path/to/your/ai/work_dir
bind_port = 8000
use_template = False
output_format = True

[feature_database]
reject_throttle = 0.61
embedding_model_path = /path/to/your/text2vec-large-chinese
reranker_model_path = /path/to/your/bce-reranker-base_v1

[model]
llm_service_address = http://127.0.0.1:8001
local_service_address = http://127.0.0.1:8002
cls_model_path = /path/of/classification
llm_model = /path/to/your/Llama3-8B-Chinese-Chat/
local_model = /path/to/your/Finetune/
max_input_length = 1400
import os
import nltk
nltk.data.path.append('/home/zhangwq/project/whl/nltk/nltk_data-gh-pages/nltk_data')
os.environ["CUDA_VISIBLE_DEVICES"] = '1'
os.environ['NLTK_DATA'] = '/home/zhangwq/project/whl/nltk/nltk_data-gh-pages/nltk_data'
import base64
import argparse
import uuid
import re
import io
from PIL import Image
from IPython.display import HTML, display
from langchain_experimental.open_clip import OpenCLIPEmbeddings
from langchain.schema.runnable import RunnableLambda, RunnablePassthrough
from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain.schema.document import Document
from langchain.storage import InMemoryStore
from langchain_community.vectorstores import FAISS
from langchain_community.docstore.in_memory import InMemoryDocstore
import faiss
from unstructured.partition.pdf import partition_pdf
from langchain_core.messages import HumanMessage
from langchain_community.chat_models import ChatVertexAI
from langchain.schema.output_parser import StrOutputParser
from langchain_core.runnables import RunnableLambda
from loguru import logger
def plt_img_base64(img_base64):
# Create an HTML img tag with the base64 string as the source
image_html = f'<img src="data:image/jpeg;base64,{img_base64}" />'
# Display the image by rendering the HTML
display(HTML(image_html))
def multi_modal_rag_chain(retriever):
"""
Multi-modal RAG chain
"""
# Multi-modal LLM
model = ChatVertexAI(
temperature=0, model_name="gemini-pro-vision", max_output_tokens=1024
)
# RAG pipeline
chain = (
{
"context": retriever | RunnableLambda(split_image_text_types),
"question": RunnablePassthrough(),
}
| RunnableLambda(img_prompt_func)
| model
| StrOutputParser()
)
return chain
def img_prompt_func(data_dict):
"""
Join the context into a single string
"""
formatted_texts = "\n".join(data_dict["context"]["texts"])
messages = []
# Adding the text for analysis
text_message = {
"type": "text",
"text": (
"You are an AI scientist tasking with providing factual answers.\n"
"You will be given a mixed of text, tables, and image(s) usually of charts or graphs.\n"
"Use this information to provide answers related to the user question. \n"
f"User-provided question: {data_dict['question']}\n\n"
"Text and / or tables:\n"
f"{formatted_texts}"
),
}
messages.append(text_message)
# Adding image(s) to the messages if present
if data_dict["context"]["images"]:
for image in data_dict["context"]["images"]:
image_message = {
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{image}"},
}
messages.append(image_message)
return [HumanMessage(content=messages)]
def split_image_text_types(docs):
"""
Split base64-encoded images and texts
"""
b64_images = []
texts = []
for doc in docs:
# Check if the document is of type Document and extract page_content if so
if isinstance(doc, Document):
doc = doc.page_content
if looks_like_base64(doc) and is_image_data(doc):
doc = resize_base64_image(doc, size=(1300, 600))
b64_images.append(doc)
else:
texts.append(doc)
if len(b64_images) > 0:
return {"images": b64_images[:1], "texts": []}
return {"images": b64_images, "texts": texts}
def resize_base64_image(base64_string, size=(128, 128)):
"""
Resize an image encoded as a Base64 string
"""
# Decode the Base64 string
img_data = base64.b64decode(base64_string)
img = Image.open(io.BytesIO(img_data))
# Resize the image
resized_img = img.resize(size, Image.LANCZOS)
# Save the resized image to a bytes buffer
buffered = io.BytesIO()
resized_img.save(buffered, format=img.format)
# Encode the resized image to Base64
return base64.b64encode(buffered.getvalue()).decode("utf-8")
def is_image_data(b64data):
"""
Check if the base64 data is an image by looking at the start of the data
"""
image_signatures = {
b"\xFF\xD8\xFF": "jpg",
b"\x89\x50\x4E\x47\x0D\x0A\x1A\x0A": "png",
b"\x47\x49\x46\x38": "gif",
b"\x52\x49\x46\x46": "webp",
}
try:
header = base64.b64decode(b64data)[:8] # Decode and get the first 8 bytes
for sig, format in image_signatures.items():
if header.startswith(sig):
return True
return False
except Exception:
return False
def looks_like_base64(sb):
"""Check if the string looks like base64"""
return re.match("^[A-Za-z0-9+/]+[=]{0,2}$", sb) is not None
def create_multi_vector_retriever(
vectorstore, text_summaries, texts, table_summaries, tables, image_summaries, images
):
"""
Create retriever that indexes summaries, but returns raw images or texts
"""
# Initialize the storage layer
store = InMemoryStore()
id_key = "doc_id"
# Create the multi-vector retriever
retriever = MultiVectorRetriever(
vectorstore=vectorstore,
docstore=store,
id_key=id_key,
)
# Helper function to add documents to the vectorstore and docstore
def add_documents(retriever, doc_summaries, doc_contents):
doc_ids = [str(uuid.uuid4()) for _ in doc_contents]
summary_docs = [
Document(page_content=s, metadata={id_key: doc_ids[i]})
for i, s in enumerate(doc_summaries)
]
retriever.vectorstore.add_documents(summary_docs)
retriever.docstore.mset(list(zip(doc_ids, doc_contents)))
# Add texts, tables, and images
# Check that text_summaries is not empty before adding
if text_summaries:
add_documents(retriever, text_summaries, texts)
# Check that table_summaries is not empty before adding
if table_summaries:
add_documents(retriever, table_summaries, tables)
# Check that image_summaries is not empty before adding
if image_summaries:
add_documents(retriever, image_summaries, images)
return retriever
def extract_elements_from_pdf(file_path: str, image_output_dir_path: str):
pdf_list = [os.path.join(file_path, file) for file in os.listdir(file_path) if file.endswith('.pdf')]
tables = []
texts = []
raw_pdf_elements = partition_pdf(
filename=pdf_list[0],
extract_images_in_pdf=True,
infer_table_structure=True,
chunking_strategy="by_title",
max_characters=4000,
new_after_n_chars=3800,
combine_text_under_n_chars=2000,
image_output_dir_path=image_output_dir_path,
)
for element in raw_pdf_elements:
if "unstructured.documents.elements.Table" in str(type(element)):
tables.append(str(element))
elif "unstructured.documents.elements.CompositeElement" in str(type(element)):
texts.append(str(element))
return texts, tables
class Summary:
def __init__(self):
pass
def encode_image(self, image_path):
"""Getting the base64 string"""
with open(image_path, "rb") as image_file:
return base64.b64encode(image_file.read()).decode("utf-8")
def image_summarize(self, img_base64, prompt):
"""Make image summary"""
model = ChatVertexAI(model_name="gemini-pro-vision", max_output_tokens=1024)
msg = model(
[
HumanMessage(
content=[
{"type": "text", "text": prompt},
{
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{img_base64}"},
},
]
)
]
)
return msg.content
def generate_img_summaries(self, path):
"""
Generate summaries and base64 encoded strings for images
path: Path to list of .jpg files extracted by Unstructured
"""
# Store base64 encoded images
img_base64_list = []
# Store image summaries
image_summaries = []
# Prompt
prompt = """You are an assistant tasked with summarizing images for retrieval. \
These summaries will be embedded and used to retrieve the raw image. \
Give a concise summary of the image that is well optimized for retrieval."""
# Apply to images
for img_file in sorted(os.listdir(path)):
if img_file.endswith(".jpg"):
img_path = os.path.join(path, img_file)
base64_image = self.encode_image(img_path)
img_base64_list.append(base64_image)
image_summaries.append(self.image_summarize(base64_image, prompt))
return img_base64_list, image_summaries
def generate_text_summaries(self, texts, tables):
text_summaries = texts
table_summaries = tables
return text_summaries, table_summaries
def parse_args():
"""Parse command-line arguments."""
parser = argparse.ArgumentParser()
parser.add_argument(
'--file_path',
type=str,
default='/home/zhangwq/data/art_test/pdf',
help='')
parser.add_argument(
'--image_output_dir_path',
default='/home/zhangwq/data/art_test',
help='')
parser.add_argument(
'--query',
default='compare and contrast between mistral and llama2 across benchmarks and explain the reasoning in detail',
help='')
args = parser.parse_args()
return args
if __name__ == '__main__':
args = parse_args()
summary = Summary()
texts, tables = extract_elements_from_pdf(file_path=args.file_path,
image_output_dir_path=args.image_output_dir_path)
text_summaries, table_summaries = summary.generate_text_summaries(texts, tables)
img_base64_list, image_summaries = summary.generate_img_summaries(args.file_path)
embeddings = OpenCLIPEmbeddings(
model_name="/home/zhangwq/model/CLIP_VIT", checkpoint="laion2b_s34b_b88k")
embeddings.client = embeddings.client.half()
# FAISS needs an explicit index and docstore (collection_name is a Chroma-style argument it does not accept)
index = faiss.IndexFlatL2(len(embeddings.embed_query("init")))
vectorstore = FAISS(embedding_function=embeddings,
index=index,
docstore=InMemoryDocstore(),
index_to_docstore_id={})
# Create retriever
retriever = create_multi_vector_retriever(
vectorstore,
text_summaries,
texts,
table_summaries,
tables,
image_summaries,
img_base64_list,
)
chain_multimodal_rag = multi_modal_rag_chain(retriever)
docs = retriever.get_relevant_documents(args.query, limit=1)
logger.info(docs[0])
chain_multimodal_rag.invoke(args.query)
import os
import time
import torch
import argparse
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
from vllm import LLM, SamplingParams

os.environ["CUDA_VISIBLE_DEVICES"] = '7'


def infer_hf_chatglm(model_path, prompt):
    '''Run chatglm2 inference with transformers.'''
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True, device_map="auto").half().cuda()
    model = model.eval()
    start_time = time.time()
    generated_text, _ = model.chat(tokenizer, prompt, history=[])
    print("chat time ", time.time() - start_time)
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    return generated_text


def infer_hf_llama3(model_path, prompt):
    '''Run llama3 inference with transformers.'''
    input_query = {"role": "user", "content": prompt}
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_path, torch_dtype="auto", device_map="auto")
    input_ids = tokenizer.apply_chat_template(
        [input_query, ], add_generation_prompt=True, return_tensors="pt").to(model.device)
    outputs = model.generate(
        input_ids,
        max_new_tokens=512,
        do_sample=True,
        temperature=1,
        top_p=0.95,
    )
    response = outputs[0][input_ids.shape[-1]:]
    generated_text = tokenizer.decode(response, skip_special_tokens=True)
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    return generated_text
def infer_vllm_llama3(model_path, message, tp_size=1, max_model_len=1024):
'''vllm 推理 llama3'''
tokenizer = AutoTokenizer.from_pretrained(model_path)
messages = [{"role": "user", "content": message}]
print(f"Prompt: {messages!r}")
sampling_params = SamplingParams(temperature=1,
top_p=0.95,
max_tokens=1024,
stop_token_ids=[tokenizer.eos_token_id])
llm = LLM(model=model_path,
max_model_len=max_model_len,
trust_remote_code=True,
enforce_eager=True,
dtype="float16",
tensor_parallel_size=tp_size)
# generate answer
start_time = time.time()
prompt_token_ids = [tokenizer.apply_chat_template(messages, add_generation_prompt=True)]
outputs = llm.generate(prompt_token_ids=prompt_token_ids, sampling_params=sampling_params)
print("total infer time", time.time() - start_time)
# results
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Generated text: {generated_text!r}")
def infer_vllm_chatglm(model_path, message, tp_size=1):
'''vllm 推理 chatglm2'''
sampling_params = SamplingParams(temperature=1.0,
top_p=0.9,
max_tokens=1024)
llm = LLM(model=model_path,
trust_remote_code=True,
enforce_eager=True,
dtype="float16",
tensor_parallel_size=tp_size)
# generate answer
print(f"chatglm2 Prompt: {message!r}")
outputs = llm.generate(message, sampling_params=sampling_params)
# results
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Generated text: {generated_text!r}")
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--model_path', default='')
parser.add_argument('--query', default="DCU是什么?", help='提问的问题.')
parser.add_argument('--use_hf', action='store_true')
args = parser.parse_args()
return args
if __name__ == "__main__":
args = parse_args()
is_llama = True if "llama" in args.model_path else False
print("Is llama", is_llama)
if args.use_hf:
# transformers
if is_llama:
infer_hf_llama3(args.model_path, args.query)
else:
infer_hf_chatglm(args.model_path, args.query)
else:
# vllm
if is_llama:
infer_vllm_llama3(args.model_path, args.query)
else:
infer_vllm_chatglm(args.model_path, args.query)
from .feature_database import FeatureDataBase, DocumentProcessor, DocumentName  # noqa E401
from .helper import TaskCode, ErrorCode, LogManager  # noqa E401
from .http_client import OpenAPIClient, ClassifyClient  # noqa E401
from .worker import Worker  # noqa E401
@@ -8,7 +8,7 @@ import hashlib
import textract
import shutil
import configparser
import json
from multiprocessing import Pool
from typing import List
from loguru import logger
@@ -18,7 +18,17 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores.faiss import FAISS
from torch.cuda import empty_cache
from bs4 import BeautifulSoup
from elastic_keywords_search import ElasticKeywordsSearch
from retriever import Retriever


def check_envs(args):
    if all(isinstance(item, int) for item in args.DCU_ID):
        os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, args.DCU_ID))
        logger.info(f"Set environment variable CUDA_VISIBLE_DEVICES to {args.DCU_ID}")
    else:
        logger.error(f"The --DCU_ID argument must be a list of integers, but got {args.DCU_ID}")
        raise ValueError("The --DCU_ID argument must be a list of integers")
class DocumentName:
@@ -143,14 +153,7 @@ class DocumentProcessor:
            text = re.sub(r'\n\s*\n', '\n\n', text)

        elif file_type == 'excel':
            text += self.read_excel(filepath)

        elif file_type == 'word' or file_type == 'ppt':
            # https://stackoverflow.com/questions/36001482/read-doc-file-with-python
@@ -177,6 +180,17 @@ class DocumentProcessor:
        return text, None

    def read_excel(self, filepath: str):
        table = None
        if filepath.endswith('.csv'):
            table = pd.read_csv(filepath)
        else:
            table = pd.read_excel(filepath)
        if table is None:
            return ''
        json_text = table.dropna(axis=1).to_json(force_ascii=False)
        return json_text

    def read_pdf(self, filepath: str):
        # load pdf and serialize table
@@ -231,7 +245,7 @@ class FeatureDataBase:
    def __init__(self,
                 embeddings: HuggingFaceEmbeddings,
                 reranker: BCERerank,
                 reject_throttle=-1) -> None:
        # logger.debug('loading text2vec model..')
        self.embeddings = embeddings
@@ -242,7 +256,7 @@ class FeatureDataBase:
        self.reject_throttle = reject_throttle if reject_throttle else -1
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1068, chunk_overlap=32)

    def get_documents(self, text, file):
        # if len(text) <= 1:
@@ -256,12 +270,17 @@ class FeatureDataBase:
            documents.append(chunk)
        return documents
    def build_database(self, files: list, work_dir: str, file_opr: DocumentProcessor, elastic_search=None):
        feature_dir = os.path.join(work_dir, 'db_response')
        if not os.path.exists(feature_dir):
            os.makedirs(feature_dir)

        documents = []
        texts_for_es = []
        metadatas_for_es = []
        ids_for_es = []
        for i, file in enumerate(files):
            if not file.status:
                continue
@@ -273,58 +292,31 @@ class FeatureDataBase:
                file.message = str(error)
                continue
            file.message = str(text[0])

            texts_for_es.append(text[0])
            metadatas_for_es.append({'source': file.basename, 'read': file.origin_path})
            ids_for_es.append(str(i))

            document = self.get_documents(text, file)
            documents += document
            logger.debug('Positive pipeline {}/{}.. register 《{}》 and split {} documents'
                         .format(i + 1, len(files), file.basename, len(document)))

        if elastic_search is not None:
            logger.debug('ES database pipeline register {} documents into database...'.format(len(texts_for_es)))
            es_time_before_register = time.time()
            elastic_search.add_texts(texts_for_es, metadatas=metadatas_for_es, ids=ids_for_es)
            es_time_after_register = time.time()
            logger.debug('ES database pipeline take time: {} '.format(es_time_after_register - es_time_before_register))

        logger.debug('Vector database pipeline register {} documents into database...'.format(len(documents)))
        ve_time_before_register = time.time()
        vs = FAISS.from_documents(documents, self.embeddings)
        vs.save_local(feature_dir)
        ve_time_after_register = time.time()
        logger.debug('Vector database pipeline take time: {} '.format(ve_time_after_register - ve_time_before_register))

    def preprocess(self, files: list, work_dir: str, file_opr: DocumentProcessor):
@@ -343,7 +335,7 @@ class FeatureDataBase:
                    file.status = False
                    file.message = 'skip image'
                elif file._category in ['pdf', 'word', 'ppt', 'html', 'excel']:
                    # read pdf/word/excel file and save to text format
                    md5 = file_opr.md5(file.origin_path)
                    file.copy_path = os.path.join(preproc_dir,
@@ -363,7 +355,7 @@ class FeatureDataBase:
                        file.status = False
                        file.message = str(e)
                elif file._category in ['json']:
                    file.status = True
                    file.copy_path = file.origin_path
                    file.message = 'preprocessed'
@@ -385,11 +377,10 @@ class FeatureDataBase:
                file.status = False
                file.message = 'read error'

    def initialize(self, files: list, work_dir: str, file_opr: DocumentProcessor, elastic_search=None):
        self.preprocess(files=files, work_dir=work_dir, file_opr=file_opr)
        self.build_database(files=files, work_dir=work_dir, file_opr=file_opr, elastic_search=elastic_search)

    def merge_db_response(self, faiss: FAISS, files: list, work_dir: str, file_opr: DocumentProcessor):
@@ -408,6 +399,7 @@ class FeatureDataBase:
                file.status = False
                file.message = str(error)
                continue
            logger.info(str(len(text)), text, str(text[0]))
            file.message = str(text[0])
            # file.message = str(len(text))
@@ -416,9 +408,13 @@ class FeatureDataBase:
            documents += self.get_documents(text, file)

        if documents:
            vs = FAISS.from_documents(documents, self.embeddings)
            if faiss:
                faiss.merge_from(vs)
                faiss.save_local(feature_dir)
            else:
                vs.save_local(feature_dir)


def test_reject(retriever: Retriever):
@@ -461,17 +457,21 @@ def parse_args():
        description='Feature store for processing directories.')
    parser.add_argument('--work_dir',
                        type=str,
                        default='',
                        help='自定义.')
    parser.add_argument(
        '--repo_dir',
        type=str,
        default='',
        help='需要读取的文件目录.')
    parser.add_argument(
        '--config_path',
        default='./ai/rag/config.ini',
        help='config目录')
    parser.add_argument(
        '--DCU_ID',
        default=[7],
        help='设置DCU')
    args = parser.parse_args()
    return args
@@ -482,53 +482,45 @@ if __name__ == '__main__':
    log_file_path = os.path.join(args.work_dir, 'application.log')
    logger.add(log_file_path, rotation='10MB', compression='zip')

    check_envs(args)

    config = configparser.ConfigParser()
    config.read(args.config_path)

    # only init vector retriever
    retriever = Retriever(config)
    fs_init = FeatureDataBase(embeddings=retriever.embeddings,
                              reranker=retriever.reranker)

    # init es retriever; drop_old decides whether to build a new index or update 'index_name'
    es_url = config.get('rag', 'es_url')
    index_name = config.get('rag', 'index_name')

    elastic_search = ElasticKeywordsSearch(
        elasticsearch_url=es_url,
        index_name=index_name,
        drop_old=True)

    # walk all files in repo dir
    file_opr = DocumentProcessor()
    files = file_opr.scan_directory(repo_dir=args.repo_dir)
    fs_init.initialize(files=files, work_dir=args.work_dir, file_opr=file_opr, elastic_search=elastic_search)
    file_opr.summarize(files)
    # del fs_init

    # with open(os.path.join(args.work_dir, 'sample', 'positive.json')) as f:
    #     positive_sample = json.load(f)
    # with open(os.path.join(args.work_dir, 'sample', 'negative.json')) as f:
    #     negative_sample = json.load(f)
    #
    # with open(os.path.join(args.work_dir, 'sample', 'positive.txt'), 'r', encoding='utf-8') as file:
    #     positive_sample = []
    #     for line in file:
    #         positive_sample.append(line.strip())
    #
    # with open(os.path.join(args.work_dir, 'sample', 'negative.txt'), 'r', encoding='utf-8') as file:
    #     negative_sample = []
    #     for line in file:
    #         negative_sample.append(line.strip())
    #
    # test_reject(retriever)
import time
import random
import argparse
from selenium import webdriver
from selenium.webdriver.common.keys import Keys
from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC
from selenium.webdriver.firefox.options import Options
from concurrent.futures import ThreadPoolExecutor, as_completed
from selenium.common.exceptions import TimeoutException, WebDriverException, NoSuchElementException
from bs4 import BeautifulSoup
from loguru import logger
class FetchBingResults:
def __init__(self,query):
self.query = query
def get_driver(self):
options = Options()
options.set_preference('permissions.default.image', 2) # 2 = do not load images
options.add_argument('--headless') # headless mode, no browser window
options.add_argument('--no-sandbox') # disable the sandbox
options.add_argument('--disable-gpu') # disable GPU hardware acceleration
options.add_argument('--disable-dev-shm-usage') # do not use /dev/shm shared memory
options.add_argument('--user-agent="Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"') # set the user agent
options.set_preference("intl.accept_languages", "en-US,en") # set the language
driver = webdriver.Firefox(options=options)
return driver
def fetch_bing_results(self, num = 1):
retries = 3
for _ in range(retries):
driver = self.get_driver()
driver.get(f'https://www.bing.com/search?q={self.query}')
try:
WebDriverWait(driver, 5).until(
EC.presence_of_element_located((By.XPATH, '//li[@class="b_algo"]'))
)
html_content = driver.page_source
soup = BeautifulSoup(html_content, 'html.parser')
b_results = soup.find('ol', {'id': 'b_results'})
results = list(b_results.find_all('li', class_='b_algo'))
search_results = []
# thread pool for fetching article contents concurrently
with ThreadPoolExecutor(max_workers = num) as executor:
# map each future to its search result
future_to_result = {
executor.submit(self.fetch_article_content, result.find('a')['href']): result
for result in results[:num]
}
for future in as_completed(future_to_result):
result = future_to_result[future]
try:
content, current_url = future.result()
title = result.find('h2').text
return content[:1000]
# search_results.append({'title': title, 'content': content, 'link': current_url})
except Exception as exc:
logger.error(f'Generated an exception: {exc}')
# return search_results
except (TimeoutException, WebDriverException) as e:
logger.error(f"Attempt {_ + 1} failed: {str(e)}")
time.sleep(random.uniform(1, 3)) # wait a random interval before retrying
finally:
driver.quit()
logger.error("All retries failed.")
return ''
def fetch_article_content(self,link):
retries = 3
for _ in range(retries):
driver = self.get_driver()
driver.get(link)
try:
try:
WebDriverWait(driver, 10).until(
EC.presence_of_element_located((By.LINK_TEXT, 'Please click here if the page does not redirect automatically ...'))
).click()
except TimeoutException:
logger.error("No redirection link found.")
WebDriverWait(driver, 10).until(
EC.presence_of_element_located((By.TAG_NAME, 'body'))
)
# run JavaScript to make sure all dynamic content has loaded
driver.execute_script("window.scrollTo(0, document.body.scrollHeight);")
# dynamically wait until the page reports it has finished loading
WebDriverWait(driver, 10).until(
lambda d: d.execute_script('return document.readyState') == 'complete'
)
article_page_source = driver.page_source
article_soup = BeautifulSoup(article_page_source, 'html.parser')
# extract the page text
content = article_soup.get_text(strip=True)
# current page URL
current_url = driver.current_url
return content,current_url
except (TimeoutException, WebDriverException, NoSuchElementException) as e:
logger.error(f"Attempt {_ + 1} failed: {str(e)}")
time.sleep(random.uniform(1, 3)) # wait a random interval before retrying
finally:
driver.quit()
return None, link
def __del__(self):
# drivers are created per request and quit inside the fetch methods; only close one if it was kept on the instance
driver = getattr(self, 'driver', None)
if driver is not None:
driver.quit()
def parse_args():
parser = argparse.ArgumentParser(
description='')
parser.add_argument(
'--query',
default='介绍下曙光DCU',
help='提问的问题.')
args = parser.parse_args()
return args
def main():
args = parse_args()
fetch_bing = FetchBingResults(args.query)
results = fetch_bing.fetch_bing_results()
print(results)
if __name__ == "__main__":
main()
@@ -30,7 +30,7 @@ class ErrorCode(Enum):
    SEARCH_FAIL = 15, 'Search fail, please check TOKEN and quota'
    NOT_FIND_RELATED_DOCS = 16, 'No relevant documents found, the following answer is generated directly by LLM.'
    NON_COMPLIANCE_QUESTION = 17, 'Non-compliance question, refusing to answer.'
    NO_WEB_SEARCH_RESULT = 18, 'Can not fetch result from web.'

    def __new__(cls, value, description):
        """Create new instance of ErrorCode."""
@@ -68,4 +68,4 @@ class LogManager:
                file.write(f'{operation}: {outcome}\n')
                file.write('\n')
        except Exception as e:
            print(e)
import os
import time
import json
import httpx
import configparser
import torch
import numpy as np
from loguru import logger
from BCEmbedding.tools.langchain import BCERerank
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.retrievers import ContextualCompressionRetriever
from langchain_community.vectorstores import FAISS
from langchain_community.vectorstores.utils import DistanceStrategy
from sklearn.metrics import precision_recall_curve
from transformers import BertForSequenceClassification, BertTokenizer
def build_history_messages(prompt, history, system: str = None):
history_messages = []
if system is not None and len(system) > 0:
history_messages.append({'role': 'system', 'content': system})
for item in history:
history_messages.append({'role': 'user', 'content': item[0]})
history_messages.append({'role': 'assistant', 'content': item[1]})
history_messages.append({'role': 'user', 'content': prompt})
return history_messages
class OpenAPIClient:
def __init__(self, url: str, model_name):
self.url = '%s/v1/chat/completions' % url
self.model_name = model_name
async def get_streaming_response(self, headers, data):
async with httpx.AsyncClient() as client:
async with client.stream("POST", self.url, json=data, headers=headers, timeout=300) as response:
async for line in response.aiter_lines():
if not line or 'DONE' in line:
continue
try:
result = json.loads(line.split('data:')[1])
output = result['choices'][0]['delta'].get('content')
except Exception as e:
logger.error('Model response parse failed: {}', e)
raise Exception('Model response parse failed.')
if not output:
continue
yield output
async def get_response(self, headers, data):
async with httpx.AsyncClient() as client:
resp = await client.post(self.url, json=data, headers=headers, timeout=300)
try:
result = json.loads(resp.content.decode("utf-8"))
output = result['choices'][0]['message']['content']
except Exception as e:
logger.error('Model response parse failed: {}', e)
raise Exception('Model response parse failed.')
return output
async def chat(self, prompt: str, history=[], stream=False):
header = {'Content-Type': 'application/json'}
# Add history to data
data = {
"model": self.model_name,
"messages": build_history_messages(prompt, history),
"stream": stream
}
logger.info("Request openapi param: {}".format(data))
if stream:
return self.get_streaming_response(header, data)
else:
return await self.get_response(header, data)
class ClassifyModel:
def __init__(self, model_path, dcu_id):
logger.info("Starting initial bert class model")
# select the DCU card before the model is moved to the device
os.environ["CUDA_VISIBLE_DEVICES"] = str(dcu_id)
logger.info(f"Set environment variable CUDA_VISIBLE_DEVICES to {dcu_id}")
self.cls_model = BertForSequenceClassification.from_pretrained(model_path).float().cuda()
self.cls_model.load_state_dict(torch.load(os.path.join(model_path, 'bert_cls_model.pth')))
self.cls_model.eval()
self.cls_tokenizer = BertTokenizer.from_pretrained(model_path)
def classfication(self, sentence):
inputs = self.cls_tokenizer(
sentence,
max_length=512,
truncation="longest_first",
return_tensors="pt")
inputs = inputs.to('cuda')
with torch.no_grad():
outputs = self.cls_model(**inputs)
logits = outputs[0]
score = torch.max(logits.data, 1)[1].tolist()
logger.info("分类结果: {}, {}".format(score[0], sentence))
return float(score[0])
class CacheRetriever:
def __init__(self, embedding_model_path: str, reranker_model_path: str, max_len: int = 4):
self.cache = dict()
self.max_len = max_len
# load text2vec and rerank model
logger.info('loading text2vec and rerank models')
self.embeddings = HuggingFaceEmbeddings(
model_name=embedding_model_path,
model_kwargs={'device': 'cuda'},
encode_kwargs={
'batch_size': 1,
'normalize_embeddings': True
})
# half
self.embeddings.client = self.embeddings.client.half()
reranker_args = {
'model': reranker_model_path,
'top_n': 3,
'device': 'cuda',
'use_fp16': True
}
self.reranker = BCERerank(**reranker_args)
def get(self,
reject_throttle: float,
fs_id: str = 'default',
work_dir='workdir'
):
if fs_id in self.cache:
self.cache[fs_id]['time'] = time.time()
return self.cache[fs_id]['retriever']
if len(self.cache) >= self.max_len:
# drop the oldest one
del_key = None
min_time = time.time()
for key, value in self.cache.items():
cur_time = value['time']
if cur_time < min_time:
min_time = cur_time
del_key = key
if del_key is not None:
del_value = self.cache[del_key]
self.cache.pop(del_key)
del del_value['retriever']
retriever = Retriever(embeddings=self.embeddings,
reranker=self.reranker,
work_dir=work_dir,
reject_throttle=reject_throttle)
self.cache[fs_id] = {'retriever': retriever, 'time': time.time()}
return retriever
def pop(self, fs_id: str):
if fs_id not in self.cache:
return
del_value = self.cache[fs_id]
self.cache.pop(fs_id)
# manually free memory
del del_value
class Retriever:
def __init__(self, embeddings, reranker, work_dir: str, reject_throttle: float) -> None:
self.reject_throttle = reject_throttle
self.rejecter = None
self.retriever = None
self.compression_retriever = None
self.embeddings = embeddings
self.reranker = reranker
self.rejecter = FAISS.load_local(
os.path.join(work_dir, 'db_response'),
embeddings=embeddings,
allow_dangerous_deserialization=True)
self.vector_store = FAISS.load_local(
os.path.join(work_dir, 'db_response'),
embeddings=embeddings,
allow_dangerous_deserialization=True,
distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT)
self.retriever = self.vector_store.as_retriever(
search_type='similarity',
search_kwargs={
'score_threshold': self.reject_throttle,
'k': 30
}
)
self.compression_retriever = ContextualCompressionRetriever(
base_compressor=reranker, base_retriever=self.retriever)
if self.rejecter is None:
logger.warning('rejecter is None')
if self.retriever is None:
logger.warning('retriever is None')
def is_relative(self, sample, k=30, disable_throttle=False):
"""If no search results below the threshold can be found from the
database, reject this query."""
if self.rejecter is None:
return False, []
if disable_throttle:
# for searching throttle during update sample
docs_with_score = self.rejecter.similarity_search_with_relevance_scores(
sample, k=1)
if len(docs_with_score) < 1:
return False, docs_with_score
return True, docs_with_score
else:
# for retrieve result
# if no chunk passed the throttle, give the max
docs_with_score = self.rejecter.similarity_search_with_relevance_scores(
sample, k=k)
ret = []
max_score = -1
top1 = None
for (doc, score) in docs_with_score:
if score >= self.reject_throttle:
ret.append(doc)
if score > max_score:
max_score = score
top1 = (doc, score)
relative = True if len(ret) > 0 else False
return relative, [top1]
def update_throttle(self,
work_dir: str,
config_path: str = 'config.ini',
positive_sample=[],
negative_sample=[]):
"""Update reject throttle based on positive and negative examples."""
import matplotlib.pyplot as plt
if len(positive_sample) == 0 or len(negative_sample) == 0:
raise Exception('positive and negative samples cannot be empty.')
all_samples = positive_sample + negative_sample
predictions = []
for sample in all_samples:
self.reject_throttle = -1
_, docs = self.is_relative(sample=sample,
disable_throttle=True)
score = docs[0][1]
predictions.append(max(0, score))
labels = [1 for _ in range(len(positive_sample))
] + [0 for _ in range(len(negative_sample))]
precision, recall, thresholds = precision_recall_curve(
labels, predictions)
plt.figure(figsize=(10, 8))
plt.plot(recall, precision, label='Precision-Recall curve')
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.title('Precision-Recall Curve')
plt.legend(loc='best')
plt.grid(True)
plt.savefig(os.path.join(work_dir, 'precision_recall_curve.png'), format='png')
plt.close()
logger.debug("Figure have been saved!")
thresholds = np.append(thresholds, 1)
max_precision = np.max(precision)
indices_with_max_precision = np.where(precision == max_precision)
optimal_recall = recall[indices_with_max_precision[0][0]]
optimal_threshold = thresholds[indices_with_max_precision[0][0]]
logger.debug(f"Optimal threshold with the highest recall under the highest precision is: {optimal_threshold}")
logger.debug(f"The corresponding precision is: {max_precision}")
logger.debug(f"The corresponding recall is: {optimal_recall}")
config = configparser.ConfigParser()
config.read(config_path)
config['feature_database']['reject_throttle'] = str(optimal_threshold)
with open(config_path, 'w') as configfile:
config.write(configfile)
logger.info(
f'Update optimal threshold: {optimal_threshold} to {config_path}' # noqa E501
)
return optimal_threshold
def query(self,
question: str,
):
time_1 = time.time()
if question is None or len(question) < 1:
# keep the return shape consistent with the other branches (chunks, references)
return [], []
if len(question) > 512:
logger.warning('input too long, truncate to 512')
question = question[0:512]
chunks = []
references = []
relative, docs = self.is_relative(sample=question)
if relative:
docs = self.compression_retriever.get_relevant_documents(question)
for doc in docs:
doc = [doc.page_content]
chunks.append(doc)
# chunks = [doc.page_content for doc in docs]
references = [doc.metadata['source'] for doc in docs]
time_2 = time.time()
logger.debug('query:{} \nchunks:{} \nreferences:{} \ntimecost:{}'
.format(question, chunks, references, time_2 - time_1))
return chunks, [os.path.basename(r) for r in references]
else:
if len(docs) > 0:
references.append(docs[0][0].metadata['source'])
logger.info('feature database rejected!')
return chunks, references
import time
import os
import configparser
import argparse
# import torch
import asyncio
import uuid
from typing import AsyncGenerator
from loguru import logger
from aiohttp import web
# from multiprocessing import Value
from transformers import AutoModelForCausalLM, AutoTokenizer
from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams
COMMON = {
    "<光合组织登记网址>": "https://www.hieco.com.cn/partner?from=timeline",
@@ -53,145 +38,4 @@ COMMON = {
    "<展板设计选择>": "1.去OA文档中心查找一些设计模板; 2. 联系专业的活动服务公司来协助设计",
    "<餐费标准>": "一般地区的餐饮费用规定为不超过300元/人(一顿正餐),特殊地区则为不超过400元/人(一顿正餐),特殊地区的具体规定请参照公司的《差旅费管理制度》",
    "": "",
}
def init_model(model_path, use_vllm=False, tensor_parallel_size=1):
## init models
# huggingface
logger.info("Starting initial model of Llama - vllm")
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# vllm
sampling_params = SamplingParams(temperature=1,
top_p=0.95,
max_tokens=1024,
early_stopping=False,
stop_token_ids=[tokenizer.eos_token_id]
)
# basic vLLM engine configuration
args = AsyncEngineArgs(model_path)
args.worker_use_ray = False
args.engine_use_ray = False
args.tokenizer = model_path
args.tensor_parallel_size = tensor_parallel_size
args.trust_remote_code = True
args.enforce_eager = True
args.max_model_len = 1024
args.dtype = 'float16'
# load the model
engine = AsyncLLMEngine.from_engine_args(args)
return engine, tokenizer, sampling_params
def llm_inference(args):
'''Start a web server that accepts HTTP requests and generates responses by calling the local LLM inference service.'''
config = configparser.ConfigParser()
config.read(args.config_path)
bind_port = int(config['default']['bind_port'])
model_path = config['llm']['local_llm_path']
tensor_parallel_size = config.getint('llm', 'tensor_parallel_size')
use_vllm = config.getboolean('llm', 'use_vllm')
stream_chat = config.getboolean('llm', 'stream_chat')
logger.info(f"Get params: model_path {model_path}, use_vllm {use_vllm}, tensor_parallel_size {tensor_parallel_size}, stream_chat {stream_chat}")
model, tokenizer, sampling_params = init_model(model_path, use_vllm=use_vllm, tensor_parallel_size=tensor_parallel_size)
async def inference(request):
start = time.time()
input_json = await request.json()
prompt = input_json['query']
history = input_json['history']
messages = [{"role": "user", "content": prompt}]
logger.info("****************** use vllm ******************")
logger.info(f"before generate {messages}")
## 1
text = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True)
print(text)
assert model is not None
request_id = str(uuid.uuid4().hex)
results_generator = model.generate(text, sampling_params=sampling_params, request_id=request_id)
# Streaming case
async def stream_results() -> AsyncGenerator[bytes, None]:
final_output = None
async for request_output in results_generator:
final_output = request_output
text_outputs = [output.text for output in request_output.outputs]
ret = {"text": text_outputs}
print(ret)
# yield (json.dumps(ret) + "\0").encode("utf-8")
# yield web.json_response({'text': text_outputs})
assert final_output is not None
return [output.text for output in final_output.outputs]
if stream_chat:
logger.info("****************** in chat stream *****************")
# return StreamingResponse(stream_results())
text = await stream_results()
output_text = substitution(text)
logger.debug('问题:{} 回答:{} \ntimecost {} '.format(prompt, output_text, time.time() - start))
return web.json_response({'text': output_text})
# Non-streaming case
final_output = None
async for request_output in results_generator:
# if await request.is_disconnected():
# # Abort the request if the client disconnects.
# await engine.abort(request_id)
# return Response(status_code=499)
final_output = request_output
assert final_output is not None
text = [output.text for output in final_output.outputs]
end = time.time()
logger.debug('问题:{} 回答:{} \ntimecost {} '.format(prompt, text, end - start))
return web.json_response({'text': text})
app = web.Application()
app.add_routes([web.post('/inference', inference)])
web.run_app(app, host='0.0.0.0', port=bind_port)
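Once the service is running, the `/inference` route accepts a JSON body with `query` and `history` fields and returns `{"text": ...}`. A minimal client sketch (host and port are assumptions; the port comes from `bind_port` in config.ini):

```python
# Minimal client sketch for the /inference endpoint (assumes localhost and bind_port=8000).
import requests

resp = requests.post(
    "http://127.0.0.1:8000/inference",
    json={"query": "介绍一下K100_AI", "history": []},
    timeout=300,
)
print(resp.json()["text"])
```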
def set_envs(dcu_ids):
try:
os.environ["CUDA_VISIBLE_DEVICES"] = dcu_ids
logger.info(f"Set environment variable CUDA_VISIBLE_DEVICES to {dcu_ids}")
except Exception as e:
logger.error(f"{e}, but got {dcu_ids}")
raise ValueError(f"{e}")
def parse_args():
'''Parse command-line arguments.'''
parser = argparse.ArgumentParser(
description='Local LLM inference service.')
parser.add_argument(
'--config_path',
default='../config.ini',
help='config.ini路径')
parser.add_argument(
'--query',
default='写一首诗',
help='提问的问题.')
parser.add_argument(
'--DCU_ID',
type=str,
default='4',
help='设置DCU卡号,卡号之间用英文逗号隔开,输入样例:"0,1,2"')
args = parser.parse_args()
return args
def main():
args = parse_args()
set_envs(args.DCU_ID)
llm_inference(args)
if __name__ == '__main__':
main()
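Assuming the entry point above lives in a script such as `llm_service.py` (the file name is not visible in this diff), it would be launched along the lines of `python llm_service.py --config_path ../config.ini --DCU_ID "0,1"`, with the DCU ids adjusted to the cards you want to expose.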
...@@ -6,37 +6,36 @@ from loguru import logger
from llm_service import Worker, llm_inference
def check_envs(args):
if all(isinstance(item, int) for item in args.DCU_ID):
os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, args.DCU_ID))
logger.info(f"Set environment variable CUDA_VISIBLE_DEVICES to {args.DCU_ID}")
else:
logger.error(f"The --DCU_ID argument must be a list of integers, but got {args.DCU_ID}")
raise ValueError("The --DCU_ID argument must be a list of integers")
def parse_args():
"""Parse args."""
parser = argparse.ArgumentParser(description='Executor.')
parser.add_argument(
'--DCU_ID',
default=[1,2,6,7],
help='设置DCU')
parser.add_argument(
'--config_path',
default='/path/to/your/ai/config.ini',
type=str,
help='config.ini路径')
parser.add_argument(
'--standalone',
default=False,
help='部署LLM推理服务.')
parser.add_argument(
'--accelerate',
default=False,
type=bool,
help='LLM推理是否启用加速'
)
args = parser.parse_args()
return args
...@@ -55,7 +54,7 @@ def build_reply_text(reply: str, references: list):
def reply_workflow(assistant):
queries = ['你好,我们公司想要购买几台测试机,请问需要联系贵公司哪位?']
for query in queries:
code, reply, references = assistant.produce_response(query=query,
history=[],
...@@ -67,19 +66,19 @@ def run():
args = parse_args()
if args.standalone is True:
import time
check_envs(args)
server_ready = Value('i', 0)
server_process = Process(target=llm_inference,
args=(args.config_path,
len(args.DCU_ID),
args.accelerate,
server_ready))
server_process.daemon = True
server_process.start()
while True:
if server_ready.value == 0:
logger.info('waiting for server to be ready..')
time.sleep(15)
elif server_ready.value == 1:
break
...
[default]
work_dir = /path/to/your/ai/work_dir
bind_port = 8003
[rag]
embedding_model_path = /path/to/your/text2vec-large-chinese
reranker_model_path = /path/to/your/bce-reranker-base_v1
vector_top_k = 5
es_top_k = 5
es_url = http://10.2.106.50:31920
index_name = dcu_knowledge_base
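A minimal sketch of how the `[rag]` section above could be read, following the same configparser pattern used in `llm_inference` (the config path is a placeholder):

```python
# Sketch: load the retrieval settings from config.ini with configparser.
import configparser

config = configparser.ConfigParser()
config.read('/path/to/your/ai/config.ini')  # placeholder path

embedding_model_path = config['rag']['embedding_model_path']
reranker_model_path = config['rag']['reranker_model_path']
vector_top_k = config.getint('rag', 'vector_top_k')
es_top_k = config.getint('rag', 'es_top_k')
es_url = config['rag']['es_url']
index_name = config['rag']['index_name']
print(es_url, index_name, vector_top_k, es_top_k)
```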
from __future__ import annotations
import argparse
import ast
import uuid
from abc import ABC
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Tuple
from loguru import logger
import jieba.analyse
from langchain.chains.llm import LLMChain
from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.llms.base import BaseLLM
from langchain.prompts.prompt import PromptTemplate
from langchain.utils import get_from_dict_or_env
from langchain.vectorstores.base import VectorStore
if TYPE_CHECKING:
from elasticsearch import Elasticsearch # noqa: F401
def _default_text_mapping() -> Dict:
return {'properties': {'text': {'type': 'text'}}}
DEFAULT_PROMPT = PromptTemplate(
input_variables=['question'],
template="""分析给定Question,提取Question中包含的KeyWords,输出列表形式
Examples:
Question: 达梦公司在过去三年中的流动比率如下:2021年:3.74倍;2020年:2.82倍;2019年:2.05倍。
KeyWords: ['过去三年', '流动比率', '2021', '3.74', '2020', '2.82', '2019', '2.05']
----------------
Question: {question}
KeyWords: """,
)
class ElasticKeywordsSearch(VectorStore, ABC):
def __init__(
self,
elasticsearch_url: str,
index_name: str,
drop_old: Optional[bool] = False,
*,
ssl_verify: Optional[Dict[str, Any]] = None,
llm_chain: Optional[LLMChain] = None,
):
try:
import elasticsearch
except ImportError:
logger.error('Could not import elasticsearch python package. '
'Please install it with `pip install elasticsearch`.')
return
self.index_name = index_name
self.llm_chain = llm_chain
self.drop_old = drop_old
_ssl_verify = ssl_verify or {}
self.elasticsearch_url = elasticsearch_url
self.ssl_verify = _ssl_verify
try:
self.client = elasticsearch.Elasticsearch(elasticsearch_url, **_ssl_verify)
except ValueError as e:
logger.error(f'Your elasticsearch client string is mis-formatted. Got error: {e}')
return
if drop_old:
try:
self.client.indices.delete(index=index_name)
except elasticsearch.exceptions.NotFoundError:
logger.info(f"Index '{index_name}' not found, nothing to delete.")
except Exception as e:
logger.error(f"Error occurred while trying to delete index '{index_name}': {e}")
logger.info(f"ElasticKeywordsSearch initialized with URL: {elasticsearch_url} and index: {index_name}")
def add_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
refresh_indices: bool = True,
**kwargs: Any,
) -> List[str]:
try:
from elasticsearch.exceptions import NotFoundError
from elasticsearch.helpers import bulk
except ImportError:
raise ImportError('Could not import elasticsearch python package. '
'Please install it with `pip install elasticsearch`.')
requests = []
ids = ids or [str(uuid.uuid4()) for _ in texts]
mapping = _default_text_mapping()
# check to see if the index already exists
try:
self.client.indices.get(index=self.index_name)
if texts and self.drop_old:
self.client.indices.delete(index=self.index_name)
self.create_index(self.client, self.index_name, mapping)
except NotFoundError:
# TODO would be nice to create index before embedding,
# just to save expensive steps for last
self.create_index(self.client, self.index_name, mapping)
for i, text in enumerate(texts):
metadata = metadatas[i] if metadatas else {}
request = {
'_op_type': 'index',
'_index': self.index_name,
'text': text,
'metadata': metadata,
'_id': ids[i],
}
requests.append(request)
bulk(self.client, requests)
if refresh_indices:
self.client.indices.refresh(index=self.index_name)
return ids
def similarity_search(self,
query: str,
k: int = 4,
query_strategy: str = 'match_phrase',
must_or_should: str = 'should',
**kwargs: Any) -> List[Document]:
if k == 0:
# pm need to control
return []
docs_and_scores = self.similarity_search_with_score(query,
k=k,
query_strategy=query_strategy,
must_or_should=must_or_should,
**kwargs)
documents = [d[0] for d in docs_and_scores]
return documents
@staticmethod
def _relevance_score_fn(distance: float) -> float:
"""Normalize the distance to a score on a scale [0, 1]."""
# Todo: normalize the es score on a scale [0, 1]
return distance
def _select_relevance_score_fn(self) -> Callable[[float], float]:
return self._relevance_score_fn
def similarity_search_with_score(self,
query: str,
k: int = 4,
query_strategy: str = 'match_phrase',
must_or_should: str = 'should',
**kwargs: Any) -> List[Tuple[Document, float]]:
if k == 0:
# pm need to control
return []
assert must_or_should in ['must', 'should'], 'only support must and should.'
# Extract keywords with the LLM chain if available, otherwise fall back to jieba
if self.llm_chain:
keywords_str = self.llm_chain.run(query)
logger.info('llm search keywords: {}'.format(keywords_str))
try:
# Parse the LLM output as a Python list literal (safer than eval)
keywords = ast.literal_eval(keywords_str)
if not isinstance(keywords, list):
raise ValueError('Keywords extracted by llm is not a list.')
except Exception as e:
logger.warning('Keyword parsing failed ({}), falling back to jieba.'.format(e))
keywords = jieba.analyse.extract_tags(query, topK=10, withWeight=False)
else:
keywords = jieba.analyse.extract_tags(query, topK=10, withWeight=False)
logger.info('jieba search keywords:{}'.format(keywords))
match_query = {'bool': {must_or_should: []}}
for key in keywords:
match_query['bool'][must_or_should].append({query_strategy: {'text': key}})
response = self.client_search(self.client, self.index_name, match_query, size=k)
hits = [hit for hit in response['hits']['hits']]
# Return (Document, score) tuples so similarity_search and external callers can unpack them
docs_and_scores = [
(Document(
page_content=hit['_source']['text'],
metadata={**hit['_source']['metadata'], 'relevance_score': hit['_score']}
), hit['_score'])
for hit in hits]
return docs_and_scores
@classmethod
def from_texts(
cls,
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
index_name: Optional[str] = None,
refresh_indices: bool = True,
llm: Optional[BaseLLM] = None,
prompt: Optional[PromptTemplate] = DEFAULT_PROMPT,
drop_old: Optional[bool] = False,
**kwargs: Any,
) -> ElasticKeywordsSearch:
elasticsearch_url = get_from_dict_or_env(kwargs, 'elasticsearch_url', 'ELASTICSEARCH_URL')
if 'elasticsearch_url' in kwargs:
del kwargs['elasticsearch_url']
index_name = index_name or uuid.uuid4().hex
if llm:
llm_chain = LLMChain(llm=llm, prompt=prompt)
vectorsearch = cls(elasticsearch_url,
index_name,
llm_chain=llm_chain,
drop_old=drop_old,
**kwargs)
else:
vectorsearch = cls(elasticsearch_url, index_name, drop_old=drop_old, **kwargs)
vectorsearch.add_texts(texts,
metadatas=metadatas,
ids=ids,
refresh_indices=refresh_indices)
return vectorsearch
def create_index(self, client: Any, index_name: str, mapping: Dict) -> None:
version_num = client.info()['version']['number'][0]
version_num = int(version_num)
if version_num >= 8:
client.indices.create(index=index_name, mappings=mapping)
else:
client.indices.create(index=index_name, body={'mappings': mapping})
def client_search(self, client: Any, index_name: str, script_query: Dict, size: int) -> Any:
version_num = client.info()['version']['number'][0]
version_num = int(version_num)
if version_num >= 8:
response = client.search(index=index_name, query=script_query, size=size, timeout='5s')
else:
response = client.search(index=index_name, body={'query': script_query, 'size': size}, timeout='5s')
return response
def delete(self, **kwargs: Any) -> None:
# TODO: Check if this can be done in bulk
self.client.indices.delete(index=self.index_name)
def read_text(filepath):
with open(filepath, encoding='utf-8') as f:
txt = f.read()
return txt
def parse_args():
"""Parse command-line arguments."""
parser = argparse.ArgumentParser(
description='Elasticsearch keyword-search demo.')
parser.add_argument(
'--elasticsearch_url',
type=str,
default='http://127.0.0.1:9200')
parser.add_argument(
'--index_name',
type=str,
default='dcu_knowledge_base')
parser.add_argument(
'--query',
type=str,
default='介绍下K100_AI?')
args = parser.parse_args()
return args
if __name__ == '__main__':
args = parse_args()
elastic_search = ElasticKeywordsSearch(
elasticsearch_url=args.elasticsearch_url,
index_name=args.index_name,
drop_old=False
)
# texts = []
# file_list = ['/home/zhangwq/data/doc_new/preprocess/dbaa1604.text',
# '/home/zhangwq/data/doc_new/preprocess/a8e2e50d.text',
# '/home/zhangwq/data/doc_new/preprocess/a3fdf916.text',
# '/home/zhangwq/data/doc_new/preprocess/9d2683f3.text',
# '/home/zhangwq/data/doc_new/preprocess/2584c250.text']
# for file in file_list:
# text = read_text(file)
# texts.append(text)
# metadatas = [
# {"source": "白皮书-K100.pdf", "type": "text"},
# {"source": "DCU人工智能基础软件系统DAS1.0介绍.pdf", "type": "text"},
# {"source": "202404-DCU优势测试项.pdf", "type": "text"},
# {"source": "202301-达芬奇架构简介.pdf", "type": "text"},
# {"source": "曙光DCU在大模型方面的布局与应用.docx", "type": "text"},
# ]
# ids = ["doc1", "doc2", "doc3", "doc4", "doc5"]
# elastic_search.add_texts(texts, metadatas=metadatas, ids=ids)
search_results = elastic_search.similarity_search_with_score(args.query, k=5)
for result in search_results:
logger.debug('Query: {} \nDoc: {} \nScore: {}'.format(args.query, result[0], result[1]))
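Besides constructing the store directly as above, the `from_texts` classmethod builds the index and adds documents in one call. A hedged sketch (URL, index name and sample text are placeholders; the `embedding` argument is accepted for VectorStore compatibility but unused by this keyword-only store):

```python
# Sketch: index a document and run a keyword search in one go.
store = ElasticKeywordsSearch.from_texts(
    texts=['K100_AI 白皮书示例文本'],            # placeholder document content
    embedding=None,                              # unused by the keyword store
    metadatas=[{'source': 'example.pdf', 'type': 'text'}],
    ids=['example-1'],
    index_name='dcu_knowledge_base_demo',        # placeholder index
    elasticsearch_url='http://127.0.0.1:9200',   # placeholder ES endpoint
)
for doc, score in store.similarity_search_with_score('K100_AI', k=1):
    print(doc.page_content, score)
```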
import os

import matplotlib.pyplot as plt

# Use a Chinese font (SimHei) so the Chinese labels render correctly
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False  # render minus signs correctly with this font
# Experiment data: hit rate and MRR (%) under increasing levels of noise
pollution_levels = ['原始', '0.5倍', '1倍', '1.5倍', '2倍']
vector_search = {
'hit_rate': [91.24, 95.99, 92.80, 86.45, 72.10],
'mrr': [97.66, 96.26, 95.67, 94.98, 96.23]
}
keyword_search = {
'hit_rate': [98.77, 98.28, 96.86, 93.93, 84.14],
'mrr': [96.97, 96.02, 94.03, 92.18, 86.23]
}
hybrid_search = {
'hit_rate': [99.80, 99.62, 99.37, 94.49, 94.10],
'mrr': [94.14, 94.55, 92.09, 89.42, 82.12]
}
hybrid_search_after_decoupling = {
'hit_rate': [99.71, 99.48, 99.04, 97.89, 92.19],
'mrr': [95.88, 94.12, 91.65, 87.90, 79.86]
}
# Create the figure: one subplot per retrieval method
fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=(24, 6))
# fig.suptitle('搜索方法比较实验', fontsize=20)
# Helper that draws one subplot
def plot_subplot(ax, data, title):
ax.plot(pollution_levels, data['hit_rate'], 'ro-', label='命中率')
ax.plot(pollution_levels, data['mrr'], 'bo-', label='平均倒数排名')
ax.set_title(title, fontsize=20)
ax.set_ylim(70, 100)
ax.set_xlabel('干扰程度', fontsize=18)
ax.set_ylabel('百分比', fontsize=18)
ax.legend(fontsize=16)
ax.grid(True, linestyle='--', alpha=0.7)
ax.tick_params(axis='both', which='major', labelsize=16)
# Rotate the x-axis tick labels so they do not overlap
plt.setp(ax.get_xticklabels(), rotation=30, ha='right')
# Draw the four subplots
plot_subplot(ax1, vector_search, '向量搜索')
plot_subplot(ax2, keyword_search, '关键字搜索')
plot_subplot(ax3, hybrid_search, '混合搜索')
plot_subplot(ax4, hybrid_search_after_decoupling, '混合搜索_架构拆解')
# Adjust the layout
plt.tight_layout()
# Save the figure
file_path = os.path.join(os.getcwd(), 'search_methods_comparison.png')
plt.savefig(file_path, dpi=300, bbox_inches='tight')
print(f"图表已保存为: {file_path}")
# Verify that the file was created successfully
if os.path.exists(file_path):
print(f"文件成功创建在: {file_path}")
print(f"文件大小: {os.path.getsize(file_path)} 字节")
else:
print(f"文件创建失败,请检查权限或其他问题")
\ No newline at end of file