Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenych
chat_demo
Commits
fe3bae99
Commit
fe3bae99
authored
Jul 15, 2024
by
Rayyyyy
Browse files
update
parent
de6f9f97
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
51 additions
and
58 deletions
+51
-58
config.ini
config.ini
+9
-9
exp/test_reject_thr.py
exp/test_reject_thr.py
+0
-1
llm_service/feature_database.py
llm_service/feature_database.py
+2
-6
llm_service/helper.py
llm_service/helper.py
+1
-1
llm_service/inferencer.py
llm_service/inferencer.py
+1
-1
llm_service/retriever.py
llm_service/retriever.py
+14
-14
llm_service/worker.py
llm_service/worker.py
+2
-4
main.py
main.py
+16
-15
server.py
server.py
+2
-3
server_start.py
server_start.py
+4
-4
No files found.
config.ini
View file @
fe3bae99
[default]
work_dir
=
/path/to/your/ai/work_dir
bind_port
=
8888
mem_threshold
=
50
dcu_threshold
=
100
work_dir
=
/path/to/your/ai/work_dir
bind_port
=
8888
mem_threshold
=
50
dcu_threshold
=
100
[feature_database]
reject_throttle
=
0.6165309870679363
embedding_model_path
=
/path/to/your/text2vec-large-chinese
reranker_model_path
=
/path/to/your/bce-reranker-base_v1
reject_throttle
=
0.6165309870679363
embedding_model_path
=
/path/to/your/text2vec-large-chinese
reranker_model_path
=
/path/to/your/bce-reranker-base_v1
[llm]
local_llm_path
=
/path/to/your/internlm-chat-7b
use_vllm
=
False
\ No newline at end of file
local_llm_path
=
/path/to/your/internlm-chat-7b
use_vllm
=
False
\ No newline at end of file
exp/test_reject_thr.py
View file @
fe3bae99
...
...
@@ -56,4 +56,3 @@ plt.legend()
# 显示图表
plt
.
show
()
llm_service/feature_database.py
View file @
fe3bae99
...
...
@@ -470,12 +470,8 @@ def parse_args():
help
=
'需要读取的文件目录.'
)
parser
.
add_argument
(
'--config_path'
,
default
=
'/
ai
/config.ini'
,
default
=
'/
home/AI_project/chat_demo
/config.ini'
,
help
=
'config目录'
)
parser
.
add_argument
(
'--DCU_ID'
,
default
=
[
4
],
help
=
'设置DCU'
)
args
=
parser
.
parse_args
()
return
args
...
...
llm_service/helper.py
View file @
fe3bae99
llm_service/inferencer.py
View file @
fe3bae99
...
...
@@ -161,7 +161,7 @@ def parse_args():
description
=
'Feature store for processing directories.'
)
parser
.
add_argument
(
'--config_path'
,
default
=
'/
home/zhangwq/project/shu_new/ai
/config.ini'
,
default
=
'/
path/of
/config.ini'
,
help
=
'config目录'
)
parser
.
add_argument
(
'--query'
,
...
...
llm_service/retriever.py
View file @
fe3bae99
...
...
@@ -15,15 +15,6 @@ from sklearn.metrics import precision_recall_curve
from
loguru
import
logger
def check_envs(args):
    """Export the DCU ids held in ``args.DCU_ID`` as ``CUDA_VISIBLE_DEVICES``.

    Args:
        args: Object exposing a ``DCU_ID`` attribute that is an iterable of
            device ids.

    Raises:
        ValueError: If any element of ``args.DCU_ID`` is not an ``int``.
    """
    device_ids = args.DCU_ID

    # Guard clause: reject anything that is not purely a collection of ints.
    if not all(isinstance(device, int) for device in device_ids):
        logger.error(f"The --DCU_ID argument must be a list of integers, but got {device_ids}")
        raise ValueError("The --DCU_ID argument must be a list of integers")

    # The environment variable is a comma-separated string of the ids.
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(device) for device in device_ids)
    logger.info(f"Set environment variable CUDA_VISIBLE_DEVICES to {device_ids}")
class
Retriever
:
def
__init__
(
self
,
embeddings
,
reranker
,
work_dir
:
str
,
reject_throttle
:
float
)
->
None
:
self
.
reject_throttle
=
reject_throttle
...
...
@@ -304,12 +295,21 @@ def test_query(retriever: Retriever, real_questions):
empty_cache
()
def set_envs(dcu_ids):
    """Export the selected DCU ids as ``CUDA_VISIBLE_DEVICES``.

    Args:
        dcu_ids: Comma-separated device ids as a string (for example
            ``"0,1,2"``). An object carrying a ``DCU_ID`` attribute, such as
            the namespace returned by ``parse_args()``, is also accepted and
            its ``DCU_ID`` value is used.

    Raises:
        ValueError: If the value cannot be stored in the environment
            (for example, it is not a string).
    """
    # ``main()`` in this file calls ``set_envs(args)`` with the whole parsed
    # namespace; assigning that object to os.environ would always fail, so
    # unwrap the ``DCU_ID`` field when it is present.
    dcu_ids = getattr(dcu_ids, "DCU_ID", dcu_ids)

    try:
        os.environ["CUDA_VISIBLE_DEVICES"] = dcu_ids
    except (TypeError, ValueError, OSError) as e:
        logger.error(f"{e}, but got {dcu_ids}")
        # Keep raising ValueError for callers, but preserve the real cause.
        raise ValueError(f"{e}") from e
    else:
        logger.info(f"Set environment variable CUDA_VISIBLE_DEVICES to {dcu_ids}")
def
parse_args
():
parser
=
argparse
.
ArgumentParser
(
description
=
'Feature store for processing directories.'
)
parser
.
add_argument
(
'--config_path'
,
default
=
'/
home/zhangwq/project/shu/ai
/config.ini'
,
default
=
'/
path/of
/config.ini'
,
help
=
'config目录'
)
parser
.
add_argument
(
'--query'
,
...
...
@@ -317,15 +317,16 @@ def parse_args():
help
=
'提问的问题.'
)
parser
.
add_argument
(
'--DCU_ID'
,
default
=
[
6
],
help
=
'设置DCU'
)
type
=
str
,
default
=
'0'
,
help
=
'设置DCU卡号,卡号之间用英文逗号隔开,输入样例:"0,1,2"'
)
args
=
parser
.
parse_args
()
return
args
def
main
():
args
=
parse_args
()
check
_envs
(
args
)
set
_envs
(
args
)
config
=
configparser
.
ConfigParser
()
config
.
read
(
args
.
config_path
)
...
...
@@ -345,4 +346,3 @@ def main():
if
__name__
==
'__main__'
:
main
()
llm_service/worker.py
View file @
fe3bae99
...
...
@@ -6,6 +6,7 @@ from .feature_database import DocumentProcessor, FeatureDataBase
class
ChatAgent
:
def
__init__
(
self
,
config
,
tensor_parallel_size
)
->
None
:
self
.
work_dir
=
config
[
'default'
][
'work_dir'
]
self
.
embedding_model_path
=
config
[
'feature_database'
][
'embedding_model_path'
]
...
...
@@ -55,12 +56,10 @@ class ChatAgent:
self
.
retriever
=
CacheRetriever
(
self
.
embedding_model_path
,
self
.
reranker_model_path
).
get
(
work_dir
=
self
.
work_dir
)
class
Worker
:
def
__init__
(
self
,
config
,
tensor_parallel_size
):
def
__init__
(
self
,
config
,
tensor_parallel_size
):
self
.
agent
=
ChatAgent
(
config
,
tensor_parallel_size
)
self
.
TOPIC_TEMPLATE
=
'告诉我这句话的主题,直接说主题不要解释:“{}”'
self
.
SCORING_RELAVANCE_TEMPLATE
=
'问题:“{}”
\n
材料:“{}”
\n
请仔细阅读以上内容,材料里为一个列表,列表里面有若干子列表,请判断每个子列表的内容和问题的相关度,不要解释直接给出相关度得分列表并以空格分隔,用0~10表示。判断标准:非常相关得 10 分;完全没关联得 0 分。
\n
'
# noqa E501
self
.
KEYWORDS_TEMPLATE
=
'谷歌搜索是一个通用搜索引擎,可用于访问互联网、查询百科知识、了解时事新闻等。搜索参数类型 string, 内容是短语或关键字,以空格分隔。
\n
你现在是搜搜小助手,用户提问“{}”,你打算通过谷歌搜索查询相关资料,请提供用于搜索的关键字或短语,不要解释直接给出关键字或短语。'
# noqa E501
...
...
@@ -110,7 +109,6 @@ class Worker:
response_direct
=
self
.
agent
.
call_llm_response
(
prompt
=
prompt
)
return
ErrorCode
.
NOT_FIND_RELATED_DOCS
,
response_direct
,
None
def
produce_response
(
self
,
query
,
history
,
judgment
,
...
...
main.py
View file @
fe3bae99
...
...
@@ -6,36 +6,37 @@ from loguru import logger
from
llm_service
import
Worker
,
llm_inference
def set_envs(dcu_ids):
    """Export the selected DCU ids as ``CUDA_VISIBLE_DEVICES``.

    Args:
        dcu_ids: Comma-separated device ids as a string (for example
            ``"0,1,2"``). An object carrying a ``DCU_ID`` attribute, such as
            the namespace returned by ``parse_args()``, is also accepted and
            its ``DCU_ID`` value is used.

    Raises:
        ValueError: If the value cannot be stored in the environment
            (for example, it is not a string).
    """
    # ``run()`` in this file calls ``set_envs(args)`` with the whole parsed
    # namespace; assigning that object to os.environ would always fail, so
    # unwrap the ``DCU_ID`` field when it is present.
    dcu_ids = getattr(dcu_ids, "DCU_ID", dcu_ids)

    try:
        os.environ["CUDA_VISIBLE_DEVICES"] = dcu_ids
    except (TypeError, ValueError, OSError) as e:
        logger.error(f"{e}, but got {dcu_ids}")
        # Keep raising ValueError for callers, but preserve the real cause.
        raise ValueError(f"{e}") from e
    else:
        logger.info(f"Set environment variable CUDA_VISIBLE_DEVICES to {dcu_ids}")


def check_envs(args):
    """Export ``args.DCU_ID`` (an iterable of ints) as ``CUDA_VISIBLE_DEVICES``.

    Args:
        args: Object exposing a ``DCU_ID`` attribute that is an iterable of
            integer device ids.

    Raises:
        ValueError: If any element of ``args.DCU_ID`` is not an ``int``.
    """
    if all(isinstance(item, int) for item in args.DCU_ID):
        # The environment variable is a comma-separated string of the ids.
        os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, args.DCU_ID))
        logger.info(f"Set environment variable CUDA_VISIBLE_DEVICES to {args.DCU_ID}")
    else:
        logger.error(f"The --DCU_ID argument must be a list of integers, but got {args.DCU_ID}")
        raise ValueError("The --DCU_ID argument must be a list of integers")
def _str2bool(value):
    """Convert a command-line token such as ``"False"`` or ``"1"`` to a bool.

    ``type=bool`` is not usable with argparse because ``bool("False")`` is
    ``True`` (any non-empty string is truthy).

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def parse_args():
    """Parse the command-line arguments of the executor.

    Returns:
        argparse.Namespace with:
            DCU_ID (str): comma-separated DCU ids, default ``'0'``.
            config_path (str): path to ``config.ini``.
            standalone (bool): default ``False``.
            use_vllm (bool): default ``False``.
    """
    parser = argparse.ArgumentParser(description='Executor.')
    parser.add_argument(
        '--DCU_ID',
        type=str,
        default='0',
        help='设置DCU卡号,卡号之间用英文逗号隔开,输入样例:"0,1,2"')
    parser.add_argument(
        '--config_path',
        default='/path/of/config.ini',
        type=str,
        help='config.ini路径')
    # nargs='?' + const=True accepts both ``--standalone`` and
    # ``--standalone True/False``; the converter yields a real bool so that
    # identity checks such as ``args.standalone is True`` can succeed.
    parser.add_argument(
        '--standalone',
        nargs='?',
        const=True,
        default=False,
        type=_str2bool,
        help='部署LLM推理服务')
    parser.add_argument(
        '--use_vllm',
        nargs='?',
        const=True,
        default=False,
        type=_str2bool,
        help='是否启用LLM推理加速')
    args = parser.parse_args()
    return args
...
...
@@ -54,7 +55,7 @@ def build_reply_text(reply: str, references: list):
def
reply_workflow
(
assistant
):
queries
=
[
'
你好,
我们公司想要购买几台测试机,请问需要联系
贵公司
哪位?'
]
queries
=
[
'我们公司想要购买几台测试机,请问需要联系哪位?'
]
for
query
in
queries
:
code
,
reply
,
references
=
assistant
.
produce_response
(
query
=
query
,
history
=
[],
...
...
@@ -66,7 +67,7 @@ def run():
args
=
parse_args
()
if
args
.
standalone
is
True
:
import
time
check
_envs
(
args
)
set
_envs
(
args
)
server_ready
=
Value
(
'i'
,
0
)
server_process
=
Process
(
target
=
llm_inference
,
args
=
(
args
.
config_path
,
...
...
@@ -78,7 +79,7 @@ def run():
server_process
.
start
()
while
True
:
if
server_ready
.
value
==
0
:
logger
.
info
(
'waiting for server to be ready.
.
'
)
logger
.
info
(
'waiting for server to be ready.'
)
time
.
sleep
(
15
)
elif
server_ready
.
value
==
1
:
break
...
...
server.py
View file @
fe3bae99
...
...
@@ -4,7 +4,6 @@ from loguru import logger
import
argparse
def
start
(
query
):
url
=
'http://127.0.0.1:8888/work'
try
:
header
=
{
'Content-Type'
:
'application/json'
}
...
...
@@ -27,7 +26,7 @@ def start(query):
def parse_args():
    """Parse the command-line arguments of the client script.

    Returns:
        argparse.Namespace with a single ``query`` attribute (str); its
        default is the placeholder text ``'输入用户问题'``.
    """
    parser = argparse.ArgumentParser(description='.')
    # Only one ``default=`` may be given; repeating the keyword is a
    # SyntaxError, so the single current value is kept.
    parser.add_argument(
        '--query',
        default='输入用户问题',
        help='')
    return parser.parse_args()
...
...
server_start.py
View file @
fe3bae99
...
...
@@ -25,7 +25,6 @@ def workflow(args):
raise
(
e
)
async
def
work
(
request
):
input_json
=
await
request
.
json
()
query
=
input_json
[
'query'
]
...
...
@@ -117,16 +116,17 @@ def auto_select_dcu(config):
def parse_args():
    """Parse the command-line arguments of the service launcher.

    Returns:
        argparse.Namespace with:
            config_path (str): default ``'/path/of/config.ini'``.
            log_path (str): default ``''`` (empty string).
    """
    parser = argparse.ArgumentParser(description='Start all services.')
    # Only one ``default=`` may be given; repeating the keyword is a
    # SyntaxError, so the single current value is kept.
    parser.add_argument(
        '--config_path',
        default='/path/of/config.ini',
        help='Config directory')
    parser.add_argument(
        '--log_path',
        default='',
        help='Set log file path')
    return parser.parse_args()
def
main
():
args
=
parse_args
()
log_path
=
'
/var
/log/assistant.log'
log_path
=
'
.
/log/assistant.log'
if
args
.
log_path
:
log_path
=
args
.
log_path
logger
.
add
(
sink
=
log_path
,
level
=
"DEBUG"
,
rotation
=
"500MB"
,
compression
=
"zip"
,
encoding
=
"utf-8"
,
enqueue
=
True
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment