"backend/apps/webui/routers/auths.py" did not exist on "6a63c94153dd76f41857019530ef0054913b5cdd"
Commit 6088e14e authored by chenych's avatar chenych
Browse files

jifu v1.0

parent 2397728d
import matplotlib.pyplot as plt
import numpy as np
# Data
groups = ['原始查询', '0.5倍污染', '1倍污染', '1.5倍污染', '2倍污染']
times_before = [3716.4464, 7040.2188, 4724.0401, 7127.4622, 5103.7318] # times before query decomposition
times_after = [1981.7389, 2085.3399, 2116.5892, 2205.7316, 3006.2746] # times after query decomposition
# Use a Chinese font; make sure SimHei is installed on the system
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# Create the figure
fig, ax = plt.subplots(figsize=(10, 6))
# Bar width
width = 0.35
# Grouped bar chart
x = np.arange(len(groups))
rects1 = ax.bar(x - width/2, times_before, width, label='拆解前', color='white', edgecolor='black')
rects2 = ax.bar(x + width/2, times_after, width, label='拆解后', hatch='//', color='white', edgecolor='black')
# Y-axis label
ax.set_ylabel('时间')
# X-axis ticks and labels
ax.set_xticks(x)
ax.set_xticklabels(groups)
# Add the legend
ax.legend()
# Title
ax.set_title('拆解前后时间对比')
# Optional annotation box in the upper-right corner
# ax.text(0.95, 0.95, '质量\n修改', transform=ax.transAxes, fontsize=9,
# verticalalignment='top', horizontalalignment='right',
# bbox=dict(boxstyle='round', facecolor='white', edgecolor='black'))
# Adjust the layout
plt.tight_layout()
# Save the chart to a file
plt.savefig('time_comparison_chart.png', dpi=300, bbox_inches='tight')
# Close the figure (free memory)
plt.close(fig)
import argparse
import fitz
import re
import os
import time
import pandas as pd
import hashlib
import textract
import shutil
import configparser
import json
from multiprocessing import Pool
from typing import List
from loguru import logger
from BCEmbedding.tools.langchain import BCERerank
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores.faiss import FAISS
from torch.cuda import empty_cache
from bs4 import BeautifulSoup
from elastic_keywords_search import ElasticKeywordsSearch
from retriever import Retriever
def check_envs(args):
if all(isinstance(item, int) for item in args.DCU_ID):
os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, args.DCU_ID))
logger.info(f"Set environment variable CUDA_VISIBLE_DEVICES to {args.DCU_ID}")
else:
logger.error(f"The --DCU_ID argument must be a list of integers, but got {args.DCU_ID}")
raise ValueError("The --DCU_ID argument must be a list of integers")
class DocumentName:
def __init__(self, directory: str, name: str, category: str):
self.directory = directory
self.prefix = name.replace('/', '_')
self.basename = os.path.basename(name)
self.origin_path = os.path.join(directory, name)
self.copy_path = ''
self._category = category
self.status = True
self.message = ''
def __str__(self):
return '{},{},{},{}\n'.format(self.basename, self.copy_path, self.status,
self.message)
class DocumentProcessor:
def __init__(self):
self.image_suffix = ['.jpg', '.jpeg', '.png', '.bmp']
self.md_suffix = '.md'
self.text_suffix = ['.txt', '.text']
self.excel_suffix = ['.xlsx', '.xls', '.csv']
self.pdf_suffix = '.pdf'
self.ppt_suffix = '.pptx'
self.html_suffix = ['.html', '.htm', '.shtml', '.xhtml']
self.word_suffix = ['.docx', '.doc']
self.json_suffix = '.json'
def md5(self, filepath: str):
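# NOTE: despite the name, this computes a SHA-256 digest and returns its first 8 hex
# characters; it is only used as a short content fingerprint for preprocessed file names.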
hash_object = hashlib.sha256()
with open(filepath, 'rb') as file:
chunk_size = 8192
while chunk := file.read(chunk_size):
hash_object.update(chunk)
return hash_object.hexdigest()[0:8]
def summarize(self, files: list):
success = 0
skip = 0
failed = 0
for file in files:
if file.status:
success += 1
elif file.message == 'skip':
skip += 1
else:
logger.info('{}文件异常, 异常信息: {} '.format(file.origin_path, file.message))
failed += 1
logger.info('解析{}文件,成功{}个,跳过{}个,异常{}个'.format(len(files), success,
skip, failed))
def read_file_type(self, filepath: str):
filepath = filepath.lower()
if filepath.endswith(self.pdf_suffix):
return 'pdf'
if filepath.endswith(self.md_suffix):
return 'md'
if filepath.endswith(self.ppt_suffix):
return 'ppt'
if filepath.endswith(self.json_suffix):
return 'json'
for suffix in self.image_suffix:
if filepath.endswith(suffix):
return 'image'
for suffix in self.text_suffix:
if filepath.endswith(suffix):
return 'text'
for suffix in self.word_suffix:
if filepath.endswith(suffix):
return 'word'
for suffix in self.excel_suffix:
if filepath.endswith(suffix):
return 'excel'
for suffix in self.html_suffix:
if filepath.endswith(suffix):
return 'html'
return None
def scan_directory(self, repo_dir: str):
documents = []
for directory, _, names in os.walk(repo_dir):
for name in names:
category = self.read_file_type(name)
if category is not None:
documents.append(
DocumentName(directory=directory, name=name, category=category))
return documents
def read(self, filepath: str):
file_type = self.read_file_type(filepath)
text = ''
if not os.path.exists(filepath):
return text
try:
if file_type == 'md' or file_type == 'text':
text = []
with open(filepath) as f:
txt = f.read()
cleaned_txt = re.sub(r'\n\s*\n', '\n\n', txt)
text.append(cleaned_txt)
elif file_type == 'pdf':
text += self.read_pdf(filepath)
text = re.sub(r'\n\s*\n', '\n\n', text)
elif file_type == 'excel':
text += self.read_excel(filepath)
elif file_type == 'word' or file_type == 'ppt':
# https://stackoverflow.com/questions/36001482/read-doc-file-with-python
# https://textract.readthedocs.io/en/latest/installation.html
text = textract.process(filepath).decode('utf8')
text = re.sub(r'\n\s*\n', '\n\n', text)
if file_type == 'ppt':
text = text.replace('\n', ' ')
elif file_type == 'html':
with open(filepath) as f:
soup = BeautifulSoup(f.read(), 'html.parser')
text += soup.text
elif filepath.endswith('.json'):
# Open the JSON file for reading
with open(filepath, 'r', encoding='utf-8') as file:
# Read all lines of the file
text = file.readlines()
except Exception as e:
logger.error((filepath, str(e)))
return '', e
return text, None
def read_excel(self, filepath: str):
table = None
if filepath.endswith('.csv'):
table = pd.read_csv(filepath)
else:
table = pd.read_excel(filepath)
if table is None:
return ''
json_text = table.dropna(axis=1).to_json(force_ascii=False)
return json_text
def read_pdf(self, filepath: str):
# load pdf and serialize table
text = ''
with fitz.open(filepath) as pages:
for page in pages:
text += page.get_text()
tables = page.find_tables()
for table in tables:
tablename = '_'.join(
filter(lambda x: x is not None and 'Col' not in x,
table.header.names))
pan = table.to_pandas()
json_text = pan.dropna(axis=1).to_json(force_ascii=False)
text += tablename
text += '\n'
text += json_text
text += '\n'
return text
def read_and_save(file: DocumentName, file_opr: DocumentProcessor):
try:
if os.path.exists(file.copy_path):
# already exists, return
logger.info('{} already processed, output file: {}, skip load'
.format(file.origin_path, file.copy_path))
return
logger.info('reading {}, would save to {}'.format(file.origin_path,
file.copy_path))
content, error = file_opr.read(file.origin_path)
if error is not None:
logger.error('{} load error: {}'.format(file.origin_path, str(error)))
return
if content is None or len(content) < 1:
logger.warning('{} empty, skip save'.format(file.origin_path))
return
cleaned_content = re.sub(r'\n\s*\n', '\n\n', content)
with open(file.copy_path, 'w') as f:
f.write(os.path.splitext(file.basename)[0] + '\n')
f.write(cleaned_content)
except Exception as e:
logger.error(f"Error in read_and_save: {e}")
class FeatureDataBase:
def __init__(self,
embeddings: HuggingFaceEmbeddings,
reranker: BCERerank,
reject_throttle=-1) -> None:
# logger.debug('loading text2vec model..')
self.embeddings = embeddings
self.reranker = reranker
self.compression_retriever = None
self.rejecter = None
self.retriever = None
self.reject_throttle = reject_throttle if reject_throttle else -1
self.text_splitter = RecursiveCharacterTextSplitter(
chunk_size=1068, chunk_overlap=32)
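# Roughly 1068-character chunks with a 32-character overlap; each chunk becomes one document in the FAISS index.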
def get_documents(self, text, file):
# if len(text) <= 1:
# return []
chunks = self.text_splitter.create_documents(text)
documents = []
for chunk in chunks:
# `source` is for return references
# `read` is for LLM response
chunk.metadata = {'source': file.basename, 'read': file.origin_path}
documents.append(chunk)
return documents
def build_database(self, files: list, work_dir: str, file_opr: DocumentProcessor, elastic_search=None):
feature_dir = os.path.join(work_dir, 'db_response')
if not os.path.exists(feature_dir):
os.makedirs(feature_dir)
documents = []
texts_for_es = []
metadatas_for_es = []
ids_for_es = []
for i, file in enumerate(files):
if not file.status:
continue
# Read each file
text, error = file_opr.read(file.copy_path)
if error is not None:
file.status = False
file.message = str(error)
continue
file.message = str(text[0])
texts_for_es.append(text[0])
metadatas_for_es.append({'source': file.basename, 'read': file.origin_path})
ids_for_es.append(str(i))
document = self.get_documents(text, file)
documents += document
logger.debug('Positive pipeline {}/{}.. register 《{}》 and split {} documents'
.format(i + 1, len(files), file.basename, len(document)))
if elastic_search is not None:
logger.debug('ES database pipeline register {} documents into database...'.format(len(texts_for_es)))
es_time_before_register = time.time()
elastic_search.add_texts(texts_for_es, metadatas=metadatas_for_es, ids=ids_for_es)
es_time_after_register = time.time()
logger.debug('ES database pipeline take time: {} '.format(es_time_after_register - es_time_before_register))
logger.debug('Vector database pipeline register {} documents into database...'.format(len(documents)))
ve_time_before_register = time.time()
vs = FAISS.from_documents(documents, self.embeddings)
vs.save_local(feature_dir)
ve_time_after_register = time.time()
logger.debug('Vector database pipeline take time: {} '.format(ve_time_after_register - ve_time_before_register))
def preprocess(self, files: list, work_dir: str, file_opr: DocumentProcessor):
preproc_dir = os.path.join(work_dir, 'preprocess')
if not os.path.exists(preproc_dir):
os.makedirs(preproc_dir)
pool = Pool(processes=16)
for idx, file in enumerate(files):
if not os.path.exists(file.origin_path):
file.status = False
file.message = 'skip not exist'
continue
if file._category == 'image':
file.status = False
file.message = 'skip image'
elif file._category in ['pdf', 'word', 'ppt', 'html', 'excel']:
# read pdf/word/excel file and save to text format
md5 = file_opr.md5(file.origin_path)
file.copy_path = os.path.join(preproc_dir,
'{}.text'.format(md5))
pool.apply_async(read_and_save, args=(file, file_opr))
elif file._category in ['md', 'text']:
# rename text files to new dir
file.copy_path = os.path.join(
preproc_dir,
file.origin_path.replace('/', '_')[-84:])
try:
shutil.copy(file.origin_path, file.copy_path)
file.status = True
file.message = 'preprocessed'
except Exception as e:
file.status = False
file.message = str(e)
elif file._category in ['json']:
file.status = True
file.copy_path = file.origin_path
file.message = 'preprocessed'
else:
file.status = False
file.message = 'skip unknown format'
pool.close()
logger.debug('waiting for preprocess read finish..')
pool.join()
# check process result
for file in files:
if file._category in ['pdf', 'word', 'excel']:
if os.path.exists(file.copy_path):
file.status = True
file.message = 'preprocessed'
else:
file.status = False
file.message = 'read error'
def initialize(self, files: list, work_dir: str, file_opr: DocumentProcessor, elastic_search=None):
self.preprocess(files=files, work_dir=work_dir, file_opr=file_opr)
self.build_database(files=files, work_dir=work_dir, file_opr=file_opr, elastic_search=elastic_search)
def merge_db_response(self, faiss: FAISS, files: list, work_dir: str, file_opr: DocumentProcessor):
feature_dir = os.path.join(work_dir, 'db_response')
if not os.path.exists(feature_dir):
os.makedirs(feature_dir)
documents = []
for i, file in enumerate(files):
logger.debug('{}/{}.. register 《{}》 into database...'.format(i + 1, len(files), file.basename))
if not file.status:
continue
# Read each file
text, error = file_opr.read(file.copy_path)
if error is not None:
file.status = False
file.message = str(error)
continue
logger.info(str(len(text)), text, str(text[0]))
file.message = str(text[0])
# file.message = str(len(text))
# logger.info('{} content length {}'.format(
# file._category, len(text)))
documents += self.get_documents(text, file)
if documents:
vs = FAISS.from_documents(documents, self.embeddings)
if faiss:
faiss.merge_from(vs)
faiss.save_local(feature_dir)
else:
vs.save_local(feature_dir)
def test_reject(retriever: Retriever):
"""Simple test reject pipeline."""
real_questions = [
'姚明是谁?',
'CBBA是啥?',
'差多少嘞?',
'cnn 的全称是什么?',
'transformer啥意思?',
'成都有什么好吃的推荐?',
'树博士是什么?',
'白马非马啥意思?',
'mmpose 如何安装?',
'今天天气如何?',
'写一首五言律诗?',
'先有鸡还是先有蛋?',
'如何在Gromacs中进行蛋白质的动态模拟?',
'wy-vSphere 7 海光平台兼容补丁?',
'在Linux系统中,如何进行源码包的安装?'
]
for example in real_questions:
relative, _ = retriever.is_relative(example)
if relative:
logger.warning(f'process query: {example}')
retriever.query(example)
empty_cache()
else:
logger.error(f'reject query: {example}')
empty_cache()
def parse_args():
"""Parse command-line arguments."""
parser = argparse.ArgumentParser(
description='Feature store for processing directories.')
parser.add_argument('--work_dir',
type=str,
default='',
help='Working directory.')
parser.add_argument(
'--repo_dir',
type=str,
default='',
help='Directory of documents to ingest.')
parser.add_argument(
'--config_path',
default='./ai/rag/config.ini',
help='Path to the config file.')
parser.add_argument(
'--DCU_ID',
type=int,
nargs='+',
default=[7],
help='DCU device IDs to use, e.g. --DCU_ID 0 1.')
args = parser.parse_args()
return args
if __name__ == '__main__':
args = parse_args()
log_file_path = os.path.join(args.work_dir, 'application.log')
logger.add(log_file_path, rotation='10MB', compression='zip')
check_envs(args)
config = configparser.ConfigParser()
config.read(args.config_path)
# only init vector retriever
embedding_model_path = config.get('rag', 'embedding_model_path') or None
reranker_model_path = config.get('rag', 'reranker_model_path') or None
if embedding_model_path and reranker_model_path:
embeddings = HuggingFaceEmbeddings(
model_name=embedding_model_path,
model_kwargs={'device': 'cuda'},
encode_kwargs={
'batch_size': 1,
'normalize_embeddings': True
})
embeddings.client = embeddings.client.half()
reranker_args = {
'model': reranker_model_path,
'top_n': int(config['rag']['vector_top_k']),
'device': 'cuda',
'use_fp16': True
}
reranker = BCERerank(**reranker_args)
fs_init = FeatureDataBase(embeddings=embeddings,
reranker=reranker)
# init the ES retriever; drop_old=True rebuilds the 'index_name' index, False updates it
es_url = config.get('rag', 'es_url')
index_name = config.get('rag', 'index_name')
elastic_search = ElasticKeywordsSearch(
elasticsearch_url=es_url,
index_name=index_name,
drop_old=True)
# walk all files in repo dir
file_opr = DocumentProcessor()
files = file_opr.scan_directory(repo_dir=args.repo_dir)
fs_init.initialize(files=files, work_dir=args.work_dir, file_opr=file_opr, elastic_search=elastic_search)
file_opr.summarize(files)
# del fs_init
# with open(os.path.join(args.work_dir, 'sample', 'positive.json')) as f:
# positive_sample = json.load(f)
# with open(os.path.join(args.work_dir, 'sample', 'negative.json')) as f:
# negative_sample = json.load(f)
#
# with open(os.path.join(args.work_dir, 'sample', 'positive.txt'), 'r', encoding='utf-8') as file:
# positive_sample = []
# for line in file:
# positive_sample.append(line.strip())
#
# with open(os.path.join(args.work_dir, 'sample', 'negative.txt'), 'r', encoding='utf-8') as file:
# negative_sample = []
# for line in file:
# negative_sample.append(line.strip())
#
# test_reject(retriever)
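The scripts in this commit read their settings from config.ini via configparser, but the file itself is not part of the diff. The sketch below only illustrates the sections and keys the code accesses ([default] work_dir / bind_port / mem_threshold / dcu_threshold and [rag] embedding_model_path / reranker_model_path / vector_top_k / es_top_k / es_url / index_name); every value is a placeholder assumption, not a recommended setting.
import configparser

# Write a minimal, assumed config.ini; adjust paths, ports and thresholds for your deployment.
config = configparser.ConfigParser()
config['default'] = {
    'work_dir': '/path/to/work_dir',   # holds preprocess/, db_response/ and log files
    'bind_port': '9003',               # port the services listen on
    'mem_threshold': '50',             # read by auto_select_dcu() (assumed percentage)
    'dcu_threshold': '50',             # read by auto_select_dcu() (assumed percentage)
}
config['rag'] = {
    'embedding_model_path': '/path/to/embedding-model',  # loaded with HuggingFaceEmbeddings
    'reranker_model_path': '/path/to/reranker-model',    # loaded with BCERerank
    'vector_top_k': '5',
    'es_top_k': '5',
    'es_url': 'http://127.0.0.1:9200',
    'index_name': 'rag_documents',
}
with open('config.ini', 'w', encoding='utf-8') as f:
    config.write(f)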
BCEmbedding
elasticsearch==8.12.0
faiss-cpu
jieba
langchain
langchain_community==0.0.38
loguru
scikit-learn==1.5.0
import argparse
import configparser
import os
import requests
import time
import uvicorn
from fastapi import FastAPI, Request
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from sklearn.metrics import precision_recall_curve
from loguru import logger
from BCEmbedding.tools.langchain import BCERerank
from langchain.retrievers import ContextualCompressionRetriever
from langchain_community.vectorstores.utils import DistanceStrategy
from requests.exceptions import RequestException
from elastic_keywords_search import ElasticKeywordsSearch
app = FastAPI()
class Retriever:
def __init__(self, config) -> None:
self.mix, self.es, self.vector = None, None, None
work_dir = config['default']['work_dir']
self.es_top_k = int(config['rag']['es_top_k'])
self.vector_top_k = int(config['rag']['vector_top_k'])
embedding_model_path = config.get('rag', 'embedding_model_path') or None
reranker_model_path = config.get('rag', 'reranker_model_path') or None
es_url = config.get('rag', 'es_url') or None
index_name = config.get('rag', 'index_name') or None
# Mix
if embedding_model_path and reranker_model_path and es_url and index_name:
self.init_mix_retriever(work_dir, embedding_model_path, reranker_model_path, es_url, index_name)
# ES
elif not embedding_model_path or not reranker_model_path:
if self.is_es_available(es_url, index_name):
self.es_retriever = ElasticKeywordsSearch(es_url, index_name, drop_old=False)
self.weights = [0.5, 0.5]
self.es = True
logger.info('Initializing ES retriever alone!')
# Vector
elif not es_url or not index_name:
self.init_vector_retriever(work_dir, embedding_model_path, reranker_model_path)
self.vector = True
logger.info('Initializing Vector retriever alone!')
else:
raise ValueError(
"Incomplete configuration. Please specify all required parameters for either vector or ES retrieval.")
def init_vector_retriever(self, work_dir, embedding_model_path, reranker_model_path):
logger.info('loading text2vec and rerank models')
self.embeddings = HuggingFaceEmbeddings(
model_name=embedding_model_path,
model_kwargs={'device': 'cuda'},
encode_kwargs={
'batch_size': 1,
'normalize_embeddings': True
})
# half
self.embeddings.client = self.embeddings.client.half()
reranker_args = {
'model': reranker_model_path,
'top_n': self.vector_top_k,
'device': 'cuda',
'use_fp16': True
}
self.reranker = BCERerank(**reranker_args)
self.vector_store = FAISS.load_local(
os.path.join(work_dir, 'db_response'),
embeddings=self.embeddings,
allow_dangerous_deserialization=True,
distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT)
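# Embeddings are normalized above (normalize_embeddings=True), so max inner product is equivalent to cosine similarity.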
retriever = self.vector_store.as_retriever(
search_type='similarity',
search_kwargs={
'score_threshold': 0.15,
'k': 30
}
)
self.compression_retriever = ContextualCompressionRetriever(
base_compressor=self.reranker, base_retriever=retriever)
def init_mix_retriever(self, work_dir, embedding_model_path, reranker_model_path, es_url, index_name):
if self.is_es_available(es_url, index_name):
self.es_retriever = ElasticKeywordsSearch(es_url, index_name, drop_old=False)
self.weights = [0.5, 0.5]
self.init_vector_retriever(work_dir, embedding_model_path, reranker_model_path)
self.mix = True
logger.info('Initializing Mix retriever!')
else:
self.init_vector_retriever(work_dir, embedding_model_path, reranker_model_path)
self.vector = True
logger.info('Initializing Vector retriever alone!')
def is_es_available(self, url, index_name, timeout=5):
try:
response = requests.get(f"{url}/_cluster/health", timeout=timeout)
if response.status_code == 200:
index_response = requests.head(f"{url}/{index_name}", timeout=timeout)
if index_response.status_code == 200:
logger.info(f"The index:'{index_name}' exist!")
return True
elif index_response.status_code == 404:
logger.warning(f"The index:'{index_name}' not exist!")
else:
logger.error(f"Unexpected status code when checking index: {index_response.status_code}")
else:
logger.error(f"Elasticsearch service returned non-200 status code: {response.status_code}")
except RequestException as e:
logger.error(f"Error connecting to Elasticsearch service: {e}")
return False
def weighted_reciprocal_rank(self, es_docs, vector_docs):
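# Reciprocal Rank Fusion: each document is scored as sum_i weights[i] / (rank_i + 60),
# where rank_i is its 1-based rank from retriever i and 60 is the usual RRF constant;
# results are then ordered by the fused score, highest first.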
# Create a union of all unique documents in the input doc_lists
all_documents = set()
for vector_doc in vector_docs:
all_documents.add(vector_doc.page_content)
for es_doc in es_docs:
all_documents.add(es_doc.page_content)
rrf_score_dic = {doc: 0.0 for doc in all_documents}
for rank, vector_doc in enumerate(vector_docs, start=1):
rrf_score = self.weights[1] * (1 / (rank + 60))
rrf_score_dic[vector_doc.page_content] += rrf_score
for rank, es_doc in enumerate(es_docs, start=1):
rrf_score = self.weights[0] * (1 / (rank + 60))
rrf_score_dic[es_doc.page_content] += rrf_score
sorted_documents = sorted(rrf_score_dic.keys(), key=lambda x: rrf_score_dic[x], reverse=True)
# Map the sorted page_content back to the original document objects
page_content_to_doc_map = {}
for doc in es_docs:
page_content_to_doc_map[doc.page_content] = doc
for doc in vector_docs:
page_content_to_doc_map[doc.page_content] = doc
sorted_docs = [page_content_to_doc_map[page_content] for page_content in sorted_documents]
return sorted_docs
def remove_duplicates(self, sorted_docs):
seen = set()
unique_docs = []
for doc in sorted_docs:
identifier = (
doc.metadata.get('source', ''),
doc.metadata.get('read', ''),
# doc.page_content # Need further testing
)
if identifier not in seen:
seen.add(identifier)
unique_docs.append(doc)
return unique_docs
def hybrid_retrieval(self, query):
es_docs = self.es_retriever.similarity_search_with_score(query, k=self.es_top_k)
vector_docs = self.query(query)
sorted_docs = self.weighted_reciprocal_rank(es_docs, vector_docs)
unique_docs = self.remove_duplicates(sorted_docs)
return unique_docs
def rag_workflow(self, query):
chunks = []
time_1 = time.time()
if self.mix:
chunks = self.hybrid_retrieval(query)
elif self.es:
chunks = self.es_retriever.similarity_search_with_score(query, k=self.es_top_k)
else:
chunks = self.query(query)
time_2 = time.time()
logger.debug(f'query:{query} \nchunks:{chunks} \ntimecost:{time_2 - time_1}')
return chunks
def query(self, question: str):
if question is None or len(question) < 1:
return None
if len(question) > 512:
logger.warning('input too long, truncate to 512')
question = question[0:512]
docs = self.compression_retriever.get_relevant_documents(question)
return docs
retriever = None
@app.post("/retrieve")
async def retrieve(request: Request):
data = await request.json()
query = data.get("query")
chunks = retriever.rag_workflow(query)
return {"chunks": chunks}
def rag_retrieve(args: str):
"""
Start a web server that accepts HTTP requests and answers them via the local RAG retrieval service.
"""
global retriever
config = configparser.ConfigParser()
config.read(args.config_path)
bind_port = int(config['default']['bind_port'])
os.environ["CUDA_VISIBLE_DEVICES"] = args.dcu_id
retriever = Retriever(config)
uvicorn.run(app, host="0.0.0.0", port=bind_port)
def parse_args():
parser = argparse.ArgumentParser(
description='Feature store for processing directories.')
parser.add_argument(
'--config_path',
default='/ai/rag/config.ini',
help='Path to the config file.')
parser.add_argument(
'--dcu_id',
default=None,
help='DCU device IDs for CUDA_VISIBLE_DEVICES, e.g. "0,1".')
args = parser.parse_args()
return args
def main():
args = parse_args()
rag_retrieve(args)
if __name__ == '__main__':
main()
import requests
import argparse
import json
import re
import time
import os
from loguru import logger
def retrieve(query, url):
"""向后端服务器发送检索请求。"""
endpoint = f"{url}/retrieve"
payload = {"query": query}
try:
response = requests.post(endpoint, json=payload)
response.raise_for_status() # raise an exception on HTTP error status codes
return response.json()
except requests.exceptions.RequestException as e:
logger.error(f"发送请求时发生错误: {e}")
return None
def remove_punctuation(query):
return re.sub(r'[^\w\s]|[_]', '', query)
def check_hit_content(query, chunks):
logger.info(f"查询: {query}")
query_move = remove_punctuation(query)
chunks = chunks.get('chunks', '')
for rank, chunk in enumerate(chunks, 1):
content = remove_punctuation(chunk['page_content'])
if query_move.lower() in content.lower():
return True, 1 / rank
return False, None
def check_hit(query, chunks):
"""检查查询是否命中任何文档。"""
json2text_path = '/home/zhangwq/data/for_test_new/json2txt'
logger.info(f"查询: {query}")
chunks = chunks.get('chunks', '')
query_move = remove_punctuation(query)
for rank, chunk in enumerate(chunks, 1):
metadata = chunk.get('metadata', '')
source = metadata.get('source', '')
logger.info(f"文档: {source}")
source_original = os.path.splitext(source)[0]
source_cleaned = remove_punctuation(source_original)
if re.match(r'json_obj_\d+\.txt', source):
# Handle json_obj_{i}.txt files
json_file_path = os.path.join(json2text_path, source)
logger.info(f"检查 JSON 文件: {json_file_path}")
try:
with open(json_file_path, 'r', encoding='utf-8') as f:
json_content = f.read()
json_content_cleaned = remove_punctuation(json_content)
if query_move.lower() in json_content_cleaned.lower():
logger.info(f"在 JSON 文件 {source} 中找到匹配!")
return True, 1 / rank
except Exception as e:
logger.error(f"读取或处理 {json_file_path} 时出错: {str(e)}")
if query_move.lower() in source_cleaned.lower() or query_move in source_cleaned:
return True, 1 / rank
return False, None
def process_queries_with_polluted_file(file_path, url):
"""Process each polluted-query line in the file."""
import ast
total_queries = 0
hits = 0
total_inverse_rank = 0
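# Metrics: hit rate = hits / total_queries; "average inverse rank" is the mean
# reciprocal rank (MRR), averaged over the queries that were hit.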
time1 = time.time()
with open(file_path, 'r', encoding='utf-8') as file:
for line in file:
try:
# data = json.loads(line.strip())
data = ast.literal_eval(line.strip())
polluted_query = data['Polluted']
original_query = data['original']
except (ValueError, SyntaxError):  # ast.literal_eval raises these on malformed lines
logger.error(f"无法解析行: {line.strip()}")
continue
except KeyError:
logger.error(f"行缺少必要的键: {line.strip()}")
continue
total_queries += 1
logger.debug(f"处理查询 {total_queries}: 污染查询 = '{polluted_query}', 原始查询 = '{original_query}'")
results = retrieve(polluted_query, url)
if results:
hit, inverse_rank = check_hit(original_query, results)
if hit:
hits += 1
total_inverse_rank += inverse_rank
logger.debug(f"命中! 倒数排名: {inverse_rank:.4f}")
else:
logger.debug("未命中")
logger.debug("-" * 50)
hit_rate = hits / total_queries if total_queries > 0 else 0
average_inverse_rank = total_inverse_rank / hits if hits > 0 else 0
time2 = time.time()
elapsed_time = time2 - time1
logger.debug(f"总查询数: {total_queries}")
logger.debug(f"命中数: {hits}")
logger.debug(f"命中率: {hit_rate:.2%}")
logger.debug(f"平均倒数排名: {average_inverse_rank:.4f}")
logger.debug(f"花费时间: {elapsed_time:.4f} 秒")
def process_queries(file_path, url):
"""处理文件中的每一行查询。"""
total_queries = 0
hits = 0
total_inverse_rank = 0
time1 = time.time()
with open(file_path, 'r', encoding='utf-8') as file:
for line in file:
query = line.strip()
if not query:
continue
total_queries += 1
results = retrieve(query, url)
if results:
hit, inverse_rank = check_hit_content(query, results)
if hit:
hits += 1
total_inverse_rank += inverse_rank
logger.debug(f"命中! 倒数排名: {inverse_rank:.4f}")
else:
logger.debug("未命中")
logger.debug("-" * 50)
hit_rate = hits / total_queries if total_queries > 0 else 0
average_inverse_rank = total_inverse_rank / hits if hits > 0 else 0
time2 = time.time()
elapsed_time = time2 - time1
logger.debug(f"总查询数: {total_queries}")
logger.debug(f"命中数: {hits}")
logger.debug(f"命中率: {hit_rate:.2%}")
logger.debug(f"平均倒数排名: {average_inverse_rank:.4f}")
logger.debug(f"花费时间: {elapsed_time:.4f} 秒")
def main():
parser = argparse.ArgumentParser(description="Client for the RAG retrieval service.")
parser.add_argument("--url", default="http://127.0.0.1:9003", help="Backend server URL")
parser.add_argument("--query", default=None, help="A single query")
parser.add_argument("--file", default=None, help="Path to a text file of queries")
parser.add_argument("--polluted_file", default='/home/zhangwq/data/query_test/N_原句_N.txt',
help="Path to a text file of polluted queries")
args = parser.parse_args()
logger.info(f"连接到服务器: {args.url}")
if args.file:
log_file_path = '/home/zhangwq/project/shu_new/ai/rag/eval/RAG_hit_rate_and_average_inverse_rank-original_query_es_alone.log'
logger.add(log_file_path, rotation='20MB', compression='zip')
process_queries(args.file, args.url)
elif args.polluted_file:
log_file_path = '/home/zhangwq/project/shu_new/ai/rag/eval/RAG_hit_rate_and_average_inverse_rank-N_原句_N_es_alone.log'
logger.add(log_file_path, rotation='20MB', compression='zip')
process_queries_with_polluted_file(args.polluted_file, args.url)
elif args.query:
results = retrieve(args.query, args.url)
if results:
hit, inverse_rank = check_hit_content(args.query, results)
if hit:
logger.info(f"命中! 倒数排名: {inverse_rank:.4f}")
else:
logger.info("未命中")
else:
logger.error("请提供查询或查询文件")
if __name__ == "__main__":
main()
accelerate
aiohttp
argcomplete==1.10.3
BCEmbedding
beautifulsoup4
chardet==3.0.4
compressed-rtf==1.0.6
docx2txt==0.8
ebcdic==1.1.1
extract-msg==0.28.7
elasticsearch==8.12.0
faiss-cpu
fastapi==0.112.2
huggingface-hub
imapclient==2.1.0
jieba
langchain==0.1.20
langchain-community==0.0.38
langchain-openai==0.1.7
@@ -11,15 +20,26 @@ langchain-core==0.1.52
langchain-text-splitters==0.0.1
langsmith==0.1.57
loguru
olefile==0.47
openpyxl==3.1.5
pandas==2.0.1
pdfminer.six==20191110
pycryptodome==3.20.0
PyMuPDF==1.24.3
python-pptx==0.6.23
requests
scikit-learn==1.5.0
selenium==4.12.0
sentence_transformers
six==1.12.0
SpeechRecognition==3.8.1
tenacity==8.3.0
textract==1.6.5
tiktoken==0.7.0
tokenizers==0.15.2
transformers==4.38.0
tzlocal==5.2
unstructured==0.11.2
uvicorn==0.30.6
XlsxWriter==3.2.0
xlrd==1.2.0
import json
import requests
from loguru import logger
import argparse
def start(query):
url = 'http://127.0.0.1:8888/work'
try:
header = {'Content-Type': 'application/json'}
# Request payload (only the query; history could be added here)
data = {
'query': query
}
resp = requests.post(url,
headers=header,
data=json.dumps(data),
timeout=300)
if resp.status_code != 200:
raise Exception(str((resp.status_code, resp.reason)))
return resp.json()['reply'], resp.json()['references']
except Exception as e:
logger.error(str(e))
return '', []  # match the (reply, references) tuple returned on success
def parse_args():
parser = argparse.ArgumentParser(description='.')
parser.add_argument('--query',
default='输入用户问题',
help='')
return parser.parse_args()
if __name__ == '__main__':
args = parse_args()
reply, ref = start(args.query)
logger.debug('reply: {} \nref: {} '.format(reply, ref))
import os
import argparse
import asyncio
import bisect
import configparser
import json
import subprocess
from aiohttp import web
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse
from loguru import logger
from llm_service import Worker
from scipy.stats import rankdata
divisible_by_32 = [1, 2, 4, 8, 16, 32]
recv_file_path = "%s/upload"
app = FastAPI()
def workflow(args):
config = configparser.ConfigParser()
config.read(args.config_path)
bind_port = int(config['default']['bind_port'])
try:
assistant = Worker(config)
except Exception as e:
raise (e)
@app.post("/work")
async def work(request: Request):
input_json = await request.json()
query = input_json['query']
history = input_json.get('history', [])
try:
code, reply, references = await assistant.produce_response(config,
query=query,
history=history)
except Exception as e:
logger.error(e)
reply = "服务异常"
references = []
return JSONResponse({'reply': reply, 'references': references})
@app.post("/stream")
async def stream(request: Request):
input_json = await request.json()
query = input_json['query']
history = input_json.get('history', [])
async def event_generator():
try:
code, reply, references = await assistant.produce_response(config,
query=query,
history=history,
stream=True)
except Exception as e:
logger.error(e)
yield "data: 服务异常\n\n"
yield 'event: end\n data: End of stream\n\n'
return
word = 'data: %s\n\n'
try:
async for request_output in reply:
text = json.dumps(request_output)
data = (word % text).encode('utf-8')
yield data
yield 'event: end\n data: End of stream\n\n'
except (asyncio.CancelledError, ConnectionResetError) as e:
logger.debug('user interrupt')
return
return StreamingResponse(event_generator(), media_type="text/event-stream")
import uvicorn
uvicorn.run(app, host='0.0.0.0', port=bind_port)
def auto_select_dcu(config):
# Read threshold in config file
mem_threshold = config.getint('default', 'mem_threshold')
dcu_threshold = config.getint('default', 'dcu_threshold')
# Get dcu usage
process = subprocess.Popen("hy-smi | grep '^[0-9]' | awk '{print $1,$6,$7}' | sed 's/%//g'", shell=True,
stdout=subprocess.PIPE, stderr=subprocess.PIPE)
output, error = process.communicate()
@@ -35,14 +98,11 @@ def auto_select_dcu(config):
continue
logger.debug("dcu id:%s, mem usage: %s dcu usage: %s" % (dcu_info[0], dcu_info[1], dcu_info[2]))
dcu_map[dcu_info[0]] = [int(dcu_info[1]), int(dcu_info[2])]
# The selected DCU count must be one of the divisors of 32 (1, 2, 4, 8, 16, 32).
# TODO: temporarily use 40% of the available device count
count = round(len(dcu_map.keys()) * 0.4)
if not count:
logger.error("There is no available dcu device, can not start the service.")
raise Exception("There is no available dcu device, can not start the service.")
insert_index = bisect.bisect_left(divisible_by_32, count)
if insert_index > 0 and count != divisible_by_32[insert_index]:
index = insert_index - 1
elif count == divisible_by_32[insert_index]:
@@ -50,10 +110,8 @@
else:
index = 0
select_count = divisible_by_32[index]
# Based on the ranking of memory and dcu usage.
dcu_mem_use_rank = [item[0] for item in dcu_map.values()]
dcu_use_rank = [item[1] for item in dcu_map.values()]
# Calculate the final ranking
final_rank = [(name, dcu_mem_use_rank[i] + dcu_use_rank[i]) for i, name in enumerate(dcu_map.keys())]
sorted_rank = sorted(final_rank, key=lambda x: x[1])
sorted_dcu_ids = [item[0] for item in sorted_rank]
@@ -62,73 +120,23 @@
logger.info(f"Set environment variable CUDA_VISIBLE_DEVICES to {select_dcu_ids}")
return select_dcu_ids
def workflow(args):
config = configparser.ConfigParser()
config.read(args.config_path)
bind_port = int(config['default']['bind_port'])
dcu_ids = auto_select_dcu(config)
tensor_parallel_size = len(dcu_ids)
try:
assistant = Worker(config, tensor_parallel_size)
except Exception as e:
raise (e)
async def work(request):
input_json = await request.json()
query = input_json['query']
code, reply, references = assistant.produce_response(query=query,
history=[],
judgment=False)
return web.json_response({'reply': reply, 'references': references})
async def handle_upload(request):
reader = await request.multipart()
while True:
field = await reader.next()
if field is None:
break
filename = field.filename
# Save to server
save_path = recv_file_path % config['default']['work_dir']
if not os.path.exists(save_path):
os.makedirs(save_path)
file_path = os.path.join(save_path, filename)
with open(file_path, 'wb') as f:
while True:
chunk = await field.read_chunk()
if not chunk:
break
f.write(chunk)
logger.debug("成功接收文件:%s" % file_path)
# Call file parse process
assistant.agent.parse_file_and_merge(save_path)
return web.json_response({"reply": "成功接收文件:{filename}\n"})
app = web.Application()
app.add_routes([web.post('/work', work),
web.post('/upload', handle_upload)])
web.run_app(app, host='0.0.0.0', port=bind_port)
def parse_args():
parser = argparse.ArgumentParser(description='Start all services.')
parser.add_argument('--config_path',
default='ai/config.ini',
help='Path to the config file.')
parser.add_argument('--log_path',
default='',
help='Log file path (defaults to /var/log/assistant.log when empty).')
return parser.parse_args()
def main():
args = parse_args()
log_path = '/var/log/assistant.log'
if args.log_path:
log_path = args.log_path
logger.add(sink=log_path, level="DEBUG", rotation="500MB", compression="zip", encoding="utf-8", enqueue=True)
workflow(args)
if __name__ == '__main__':
main()
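The /stream endpoint above emits server-sent events: one "data: <json>" block per generated chunk, followed by an "event: end" message. Below is a minimal client sketch under those assumptions; the host and port 8888 are taken from the sample client earlier in this commit and must match bind_port in config.ini.
import json
import requests

def stream_chat(query, url='http://127.0.0.1:8888/stream'):  # host/port assumed from config.ini
    payload = {'query': query, 'history': []}
    with requests.post(url, json=payload, stream=True, timeout=300) as resp:
        resp.raise_for_status()
        for raw in resp.iter_lines(decode_unicode=True):
            line = (raw or '').strip()
            if not line.startswith('data: '):
                continue                      # skip blank keep-alives and event markers
            body = line[len('data: '):]
            if body == 'End of stream':
                break                         # terminator sent with the 'event: end' message
            try:
                yield json.loads(body)        # each chunk is JSON-encoded by the server
            except json.JSONDecodeError:
                yield body                    # e.g. the plain-text error reply

if __name__ == '__main__':
    for chunk in stream_chat('你好'):
        print(chunk)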
[
"阿树是谁",
"具体在哪些位置进行修改?",
"你是谁?",
"1+1",
"你好",
"需要修改的内容的多吗",
"你好,介绍下自己",
"你能干什么",
"啊,不是吧",
"FCOS",
"那我怎么打印",
"?",
"MMSelfSup",
"百度",
"啥?",
"你是谁",
"那历史方面的问题呢",
"你被预设了什么人设",
"SIM卡鉴权过程中,跟踪区域码TAC起到了什么作用?请详述双向鉴权过程,尤其涉及TAC以及小区ID部分",
"DCOS",
"你帅吗",
"有新手入门教程吗",
"你说你是",
"你把两个问题合起来想一想",
"OpoenMMLab 会被取代吗",
"为什么",
"MMSelfSup有什么用?",
"群号有吗",
"有交流群吗",
"你会哪些问题啊",
"本垃圾在缓慢学习这些玩意",
"能不能找到上面的安装手册呢?",
"xtcocotools安装不了",
"你这是llm模型吗",
"在线难样本挖掘",
"ncnn全名是什么",
"先有鸡还是先有蛋?"
]
[
"请问DCU中如何查看显卡显存等信息?",
"Z100支持什么类型的计算精度?",
"请问DCU中如何查看显卡显存等信息?",
"能否概括一下DCU软件栈?",
"DCU如何实现分布式训练?",
"Rocm-smi的输出正常,但是rocminfo报错",
"什么是miopen?",
"怎样通过容器分割物理机上的DCU加速卡?",
"yolov5的内存出现memory access fault是因为什么?",
"为什么运行时找不到rocblas库",
"什么是服务器?",
"你能解释一下云服务器是什么吗?",
"在什么情况下可以提升服务器的性能?",
"什么是交换机?",
"负载均衡器是什么?",
"CDN是什么?",
"什么是人工智能(AI)?",
"深度学习指的是什么?",
"人工智能的伦理和道德问题有哪些?",
"人工智能与机器学习有何不同之处?",
"AI加速卡的工作原理是什么?",
"如何部署和使用AI加速卡?",
"DCU加速器计算单元微架构是什么样的?",
"有哪些HIP主机编程接口可供使用?",
"并行计算计算机的结构体系是什么样的?",
"MIGraphX都有哪些特性?",
"DTK是做什么用的?",
"曙光的HPC有哪些技术优势?",
"DCU在人工智能领域有哪些实际应用?",
"在DCU上如何配置onnxruntime环境?",
"为什么在本地可以运行,但在容器中无法运行?",
"MIOpen发生错误时应该如何处理?",
"如何进行DCU代码的移植?",
"DCU支持哪些深度学习框架或工具?",
"请问DCU上支持的大模型目前有哪些?",
"请问ac平台上slurm如何申请DCU资源?",
"请问DCU的AI生态包可以从哪里下载?"
]