Commit fe3bae99 authored by Rayyyyy
Browse files

update

parent de6f9f97
[default] [default]
work_dir = /path/to/your/ai/work_dir work_dir=/path/to/your/ai/work_dir
bind_port = 8888 bind_port=8888
mem_threshold = 50 mem_threshold=50
dcu_threshold = 100 dcu_threshold=100
[feature_database] [feature_database]
reject_throttle = 0.6165309870679363 reject_throttle=0.6165309870679363
embedding_model_path = /path/to/your/text2vec-large-chinese embedding_model_path=/path/to/your/text2vec-large-chinese
reranker_model_path = /path/to/your/bce-reranker-base_v1 reranker_model_path=/path/to/your/bce-reranker-base_v1
[llm] [llm]
local_llm_path = /path/to/your/internlm-chat-7b local_llm_path=/path/to/your/internlm-chat-7b
use_vllm = False use_vllm=False
\ No newline at end of file \ No newline at end of file
...@@ -56,4 +56,3 @@ plt.legend() ...@@ -56,4 +56,3 @@ plt.legend()
# 显示图表 # 显示图表
plt.show() plt.show()
...@@ -470,12 +470,8 @@ def parse_args(): ...@@ -470,12 +470,8 @@ def parse_args():
help='需要读取的文件目录.') help='需要读取的文件目录.')
parser.add_argument( parser.add_argument(
'--config_path', '--config_path',
default='/ai/config.ini', default='/home/AI_project/chat_demo/config.ini',
help='config目录') help='config目录')
parser.add_argument(
'--DCU_ID',
default=[4],
help='设置DCU')
args = parser.parse_args() args = parser.parse_args()
return args return args
...@@ -535,4 +531,4 @@ if __name__ == '__main__': ...@@ -535,4 +531,4 @@ if __name__ == '__main__':
# test # test
retriever = cache.get(reject_throttle=reject_throttle, retriever = cache.get(reject_throttle=reject_throttle,
work_dir=args.work_dir) work_dir=args.work_dir)
test_reject(retriever) test_reject(retriever)
\ No newline at end of file
...@@ -68,4 +68,4 @@ class LogManager: ...@@ -68,4 +68,4 @@ class LogManager:
file.write(f'{operation}: {outcome}\n') file.write(f'{operation}: {outcome}\n')
file.write('\n') file.write('\n')
except Exception as e: except Exception as e:
print(e) print(e)
\ No newline at end of file
...@@ -161,7 +161,7 @@ def parse_args(): ...@@ -161,7 +161,7 @@ def parse_args():
description='Feature store for processing directories.') description='Feature store for processing directories.')
parser.add_argument( parser.add_argument(
'--config_path', '--config_path',
default='/home/zhangwq/project/shu_new/ai/config.ini', default='/path/of/config.ini',
help='config目录') help='config目录')
parser.add_argument( parser.add_argument(
'--query', '--query',
......
...@@ -15,15 +15,6 @@ from sklearn.metrics import precision_recall_curve ...@@ -15,15 +15,6 @@ from sklearn.metrics import precision_recall_curve
from loguru import logger from loguru import logger
def check_envs(args):
    """Export the DCU cards listed in ``args.DCU_ID`` via CUDA_VISIBLE_DEVICES.

    Args:
        args: Parsed command-line namespace; its ``DCU_ID`` attribute must be
            an iterable containing only integers.

    Raises:
        ValueError: If any element of ``args.DCU_ID`` is not an integer.
    """
    dcu_ids = args.DCU_ID

    # Guard clause: reject the whole value as soon as one entry is not an int.
    if any(not isinstance(entry, int) for entry in dcu_ids):
        logger.error(f"The --DCU_ID argument must be a list of integers, but got {dcu_ids}")
        raise ValueError("The --DCU_ID argument must be a list of integers")

    # The environment variable is the comma-joined list of card indices.
    visible_devices = ','.join(str(entry) for entry in dcu_ids)
    os.environ["CUDA_VISIBLE_DEVICES"] = visible_devices
    logger.info(f"Set environment variable CUDA_VISIBLE_DEVICES to {dcu_ids}")
class Retriever: class Retriever:
def __init__(self, embeddings, reranker, work_dir: str, reject_throttle: float) -> None: def __init__(self, embeddings, reranker, work_dir: str, reject_throttle: float) -> None:
self.reject_throttle = reject_throttle self.reject_throttle = reject_throttle
...@@ -304,12 +295,21 @@ def test_query(retriever: Retriever, real_questions): ...@@ -304,12 +295,21 @@ def test_query(retriever: Retriever, real_questions):
empty_cache() empty_cache()
def set_envs(dcu_ids):
    """Expose the selected DCU cards to this process via CUDA_VISIBLE_DEVICES.

    Args:
        dcu_ids: Comma-separated card indices such as ``"0,1,2"``. An object
            carrying a ``DCU_ID`` attribute (the parsed argparse namespace)
            or a list/tuple of indices is accepted as well.

    Raises:
        ValueError: If the value cannot be stored in the environment.
    """
    # main() calls ``set_envs(args)`` with the whole namespace; assigning that
    # object to os.environ raises TypeError, so unwrap the option value first.
    dcu_ids = getattr(dcu_ids, "DCU_ID", dcu_ids)

    # Keep the older list-style value (e.g. [1, 2]) working.
    if isinstance(dcu_ids, (list, tuple)):
        dcu_ids = ",".join(str(item) for item in dcu_ids)

    try:
        os.environ["CUDA_VISIBLE_DEVICES"] = dcu_ids
    except Exception as e:
        logger.error(f"{e}, but got {dcu_ids}")
        # Chain the cause so the original traceback is not lost.
        raise ValueError(f"{e}") from e
    else:
        logger.info(f"Set environment variable CUDA_VISIBLE_DEVICES to {dcu_ids}")
def parse_args(): def parse_args():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description='Feature store for processing directories.') description='Feature store for processing directories.')
parser.add_argument( parser.add_argument(
'--config_path', '--config_path',
default='/home/zhangwq/project/shu/ai/config.ini', default='/path/of/config.ini',
help='config目录') help='config目录')
parser.add_argument( parser.add_argument(
'--query', '--query',
...@@ -317,15 +317,16 @@ def parse_args(): ...@@ -317,15 +317,16 @@ def parse_args():
help='提问的问题.') help='提问的问题.')
parser.add_argument( parser.add_argument(
'--DCU_ID', '--DCU_ID',
default=[6], type=str,
help='设置DCU') default='0',
help='设置DCU卡号,卡号之间用英文逗号隔开,输入样例:"0,1,2"')
args = parser.parse_args() args = parser.parse_args()
return args return args
def main(): def main():
args = parse_args() args = parse_args()
check_envs(args) set_envs(args)
config = configparser.ConfigParser() config = configparser.ConfigParser()
config.read(args.config_path) config.read(args.config_path)
...@@ -345,4 +346,3 @@ def main(): ...@@ -345,4 +346,3 @@ def main():
if __name__ == '__main__': if __name__ == '__main__':
main() main()
...@@ -6,6 +6,7 @@ from .feature_database import DocumentProcessor, FeatureDataBase ...@@ -6,6 +6,7 @@ from .feature_database import DocumentProcessor, FeatureDataBase
class ChatAgent: class ChatAgent:
def __init__(self, config, tensor_parallel_size) -> None: def __init__(self, config, tensor_parallel_size) -> None:
self.work_dir = config['default']['work_dir'] self.work_dir = config['default']['work_dir']
self.embedding_model_path = config['feature_database']['embedding_model_path'] self.embedding_model_path = config['feature_database']['embedding_model_path']
...@@ -55,12 +56,10 @@ class ChatAgent: ...@@ -55,12 +56,10 @@ class ChatAgent:
self.retriever = CacheRetriever(self.embedding_model_path, self.reranker_model_path).get(work_dir=self.work_dir) self.retriever = CacheRetriever(self.embedding_model_path, self.reranker_model_path).get(work_dir=self.work_dir)
class Worker: class Worker:
def __init__(self, config, tensor_parallel_size):
def __init__(self, config, tensor_parallel_size):
self.agent = ChatAgent(config, tensor_parallel_size) self.agent = ChatAgent(config, tensor_parallel_size)
self.TOPIC_TEMPLATE = '告诉我这句话的主题,直接说主题不要解释:“{}”' self.TOPIC_TEMPLATE = '告诉我这句话的主题,直接说主题不要解释:“{}”'
self.SCORING_RELAVANCE_TEMPLATE = '问题:“{}”\n材料:“{}”\n请仔细阅读以上内容,材料里为一个列表,列表里面有若干子列表,请判断每个子列表的内容和问题的相关度,不要解释直接给出相关度得分列表并以空格分隔,用0~10表示。判断标准:非常相关得 10 分;完全没关联得 0 分。\n' # noqa E501 self.SCORING_RELAVANCE_TEMPLATE = '问题:“{}”\n材料:“{}”\n请仔细阅读以上内容,材料里为一个列表,列表里面有若干子列表,请判断每个子列表的内容和问题的相关度,不要解释直接给出相关度得分列表并以空格分隔,用0~10表示。判断标准:非常相关得 10 分;完全没关联得 0 分。\n' # noqa E501
self.KEYWORDS_TEMPLATE = '谷歌搜索是一个通用搜索引擎,可用于访问互联网、查询百科知识、了解时事新闻等。搜索参数类型 string, 内容是短语或关键字,以空格分隔。\n你现在是搜搜小助手,用户提问“{}”,你打算通过谷歌搜索查询相关资料,请提供用于搜索的关键字或短语,不要解释直接给出关键字或短语。' # noqa E501 self.KEYWORDS_TEMPLATE = '谷歌搜索是一个通用搜索引擎,可用于访问互联网、查询百科知识、了解时事新闻等。搜索参数类型 string, 内容是短语或关键字,以空格分隔。\n你现在是搜搜小助手,用户提问“{}”,你打算通过谷歌搜索查询相关资料,请提供用于搜索的关键字或短语,不要解释直接给出关键字或短语。' # noqa E501
...@@ -110,7 +109,6 @@ class Worker: ...@@ -110,7 +109,6 @@ class Worker:
response_direct = self.agent.call_llm_response(prompt=prompt) response_direct = self.agent.call_llm_response(prompt=prompt)
return ErrorCode.NOT_FIND_RELATED_DOCS, response_direct, None return ErrorCode.NOT_FIND_RELATED_DOCS, response_direct, None
def produce_response(self, query, def produce_response(self, query,
history, history,
judgment, judgment,
......
...@@ -6,36 +6,37 @@ from loguru import logger ...@@ -6,36 +6,37 @@ from loguru import logger
from llm_service import Worker, llm_inference from llm_service import Worker, llm_inference
def set_envs(dcu_ids):
    """Expose the selected DCU cards to this process via CUDA_VISIBLE_DEVICES.

    Args:
        dcu_ids: Comma-separated card indices such as ``"0,1,2"``. An object
            carrying a ``DCU_ID`` attribute (the parsed argparse namespace)
            or a list/tuple of indices is accepted as well.

    Raises:
        ValueError: If the value cannot be stored in the environment.
    """
    # run() calls ``set_envs(args)`` with the whole namespace; assigning that
    # object to os.environ raises TypeError, so unwrap the option value first.
    dcu_ids = getattr(dcu_ids, "DCU_ID", dcu_ids)

    # Keep the older list-style value (e.g. [1, 2, 6, 7]) working.
    if isinstance(dcu_ids, (list, tuple)):
        dcu_ids = ",".join(str(item) for item in dcu_ids)

    try:
        os.environ["CUDA_VISIBLE_DEVICES"] = dcu_ids
    except Exception as e:
        logger.error(f"{e}, but got {dcu_ids}")
        # Chain the cause so the original traceback is not lost.
        raise ValueError(f"{e}") from e
    else:
        logger.info(f"Set environment variable CUDA_VISIBLE_DEVICES to {dcu_ids}")
def _str2bool(value):
    """Convert a command-line token such as "True", "false" or "0" to a bool."""
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    # The empty string stays falsy, matching what ``bool('')`` used to give.
    if token in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(
        f'expected a boolean value, got {value!r}')


def parse_args(argv=None):
    """Parse the executor's command-line options.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace with ``DCU_ID`` (str), ``config_path`` (str),
        ``standalone`` (bool) and ``use_vllm`` (bool).
    """
    parser = argparse.ArgumentParser(description='Executor.')
    parser.add_argument(
        '--DCU_ID',
        type=str,
        default='0',
        help='设置DCU卡号,卡号之间用英文逗号隔开,输入样例:"0,1,2"')
    parser.add_argument(
        '--config_path',
        default='/path/of/config.ini',
        type=str,
        help='config.ini路径')
    # Without a converter the value arrives as a str, so the caller's
    # ``args.standalone is True`` check could never succeed from the CLI.
    # ``nargs='?'`` with ``const=True`` also allows a bare ``--standalone``.
    parser.add_argument(
        '--standalone',
        nargs='?',
        const=True,
        default=False,
        type=_str2bool,
        help='部署LLM推理服务')
    # ``type=bool`` turned every non-empty string, including "False", into
    # True; parse the text explicitly instead.
    parser.add_argument(
        '--use_vllm',
        nargs='?',
        const=True,
        default=False,
        type=_str2bool,
        help='是否启用LLM推理加速')
    args = parser.parse_args(argv)
    return args
...@@ -54,7 +55,7 @@ def build_reply_text(reply: str, references: list): ...@@ -54,7 +55,7 @@ def build_reply_text(reply: str, references: list):
def reply_workflow(assistant): def reply_workflow(assistant):
queries = ['你好,我们公司想要购买几台测试机,请问需要联系贵公司哪位?'] queries = ['我们公司想要购买几台测试机,请问需要联系哪位?']
for query in queries: for query in queries:
code, reply, references = assistant.produce_response(query=query, code, reply, references = assistant.produce_response(query=query,
history=[], history=[],
...@@ -66,7 +67,7 @@ def run(): ...@@ -66,7 +67,7 @@ def run():
args = parse_args() args = parse_args()
if args.standalone is True: if args.standalone is True:
import time import time
check_envs(args) set_envs(args)
server_ready = Value('i', 0) server_ready = Value('i', 0)
server_process = Process(target=llm_inference, server_process = Process(target=llm_inference,
args=(args.config_path, args=(args.config_path,
...@@ -78,7 +79,7 @@ def run(): ...@@ -78,7 +79,7 @@ def run():
server_process.start() server_process.start()
while True: while True:
if server_ready.value == 0: if server_ready.value == 0:
logger.info('waiting for server to be ready..') logger.info('waiting for server to be ready.')
time.sleep(15) time.sleep(15)
elif server_ready.value == 1: elif server_ready.value == 1:
break break
......
...@@ -4,7 +4,6 @@ from loguru import logger ...@@ -4,7 +4,6 @@ from loguru import logger
import argparse import argparse
def start(query): def start(query):
url = 'http://127.0.0.1:8888/work' url = 'http://127.0.0.1:8888/work'
try: try:
header = {'Content-Type': 'application/json'} header = {'Content-Type': 'application/json'}
...@@ -27,7 +26,7 @@ def start(query): ...@@ -27,7 +26,7 @@ def start(query):
def parse_args(argv=None):
    """Parse the client's command-line options.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, as before; passing a list lets
            the parser be driven programmatically.

    Returns:
        argparse.Namespace with a single ``query`` attribute.
    """
    parser = argparse.ArgumentParser(description='.')
    parser.add_argument('--query',
                        default='输入用户问题',
                        help='')
    return parser.parse_args(argv)
...@@ -36,4 +35,4 @@ if __name__ == '__main__': ...@@ -36,4 +35,4 @@ if __name__ == '__main__':
args = parse_args() args = parse_args()
reply, ref = start(args.query) reply, ref = start(args.query)
logger.debug('reply: {} \nref: {} '.format(reply, logger.debug('reply: {} \nref: {} '.format(reply,
ref)) ref))
\ No newline at end of file
...@@ -25,7 +25,6 @@ def workflow(args): ...@@ -25,7 +25,6 @@ def workflow(args):
raise (e) raise (e)
async def work(request): async def work(request):
input_json = await request.json() input_json = await request.json()
query = input_json['query'] query = input_json['query']
...@@ -117,16 +116,17 @@ def auto_select_dcu(config): ...@@ -117,16 +116,17 @@ def auto_select_dcu(config):
def parse_args(argv=None):
    """Parse the service launcher's command-line options.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, as before; passing a list lets
            the parser be driven programmatically.

    Returns:
        argparse.Namespace with ``config_path`` and ``log_path``. An empty
        ``log_path`` means none was given on the command line.
    """
    parser = argparse.ArgumentParser(description='Start all services.')
    parser.add_argument('--config_path',
                        default='/path/of/config.ini',
                        help='Config directory')
    parser.add_argument('--log_path',
                        default='',
                        help='Set log file path')
    return parser.parse_args(argv)
def main(): def main():
args = parse_args() args = parse_args()
log_path = '/var/log/assistant.log' log_path = './log/assistant.log'
if args.log_path: if args.log_path:
log_path = args.log_path log_path = args.log_path
logger.add(sink=log_path, level="DEBUG", rotation="500MB", compression="zip", encoding="utf-8", enqueue=True) logger.add(sink=log_path, level="DEBUG", rotation="500MB", compression="zip", encoding="utf-8", enqueue=True)
...@@ -134,4 +134,4 @@ def main(): ...@@ -134,4 +134,4 @@ def main():
if __name__ == '__main__': if __name__ == '__main__':
main() main()
\ No newline at end of file
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment