chenych / chat_demo · Commits

Commit cf14b883, authored Sep 14, 2024 by chenych
update
parent 6088e14e
Showing 9 changed files with 39 additions and 648 deletions (+39 −648)
client.py                        +1   −1
config.ini                       +7   −7
llm_service/__init__.py          +2   −2
llm_service/feature_database.py  +0   −526
llm_service/http_client.py       +2   −3
llm_service/retriever.py         +1   −4
llm_service/worker.py            +12  −45
rag/feature_database.py          +7   −53
server_start.py                  +7   −7
client.py

@@ -51,7 +51,7 @@ def get_streaming_response(response: requests.Response):
             print(char, end="", flush=True)


-def stream_query(query):
+def stream_query(query, user_id=None):
     url = base_url % 'stream'
     try:
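For context, a minimal sketch of how the extended helper might be called; only the stream_query(query, user_id=None) signature comes from the diff above, and the values are placeholders:

    # Hypothetical call sites for the updated client helper.
    stream_query('服务器如何续保?')                       # user_id defaults to None
    stream_query('服务器如何续保?', user_id='demo-user')  # assumption: per-user tracking downstream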
config.ini

 [default]
-work_dir = /path/to/your/ai/work_dir
+work_dir = /home/
 bind_port = 8000
 use_template = False
 output_format = True

 [feature_database]
 reject_throttle = 0.61
-embedding_model_path = /path/to/your/text2vec-large-chinese
+embedding_model_path = /home/Embedding_model/text2vec-large-chinese
-reranker_model_path = /path/to/your/bce-reranker-base_v1
+reranker_model_path = /home/Embedding_model/bce-reranker-base_v1

 [model]
 llm_service_address = http://127.0.0.1:8001
 local_service_address = http://127.0.0.1:8002
-cls_model_path = /path/of/classification
+cls_model_path = /home/llm_model/bert-base-chinese
-llm_model = /path/to/your/Llama3-8B-Chinese-Chat/
+llm_model = /home/llm_model/Llama3.1-8B-Chinese-Chat/
-local_model = /path/to/your/Finetune/
+local_model = /home/llm_model/llama3-jifu-0717-1024
 max_input_length = 1400
\ No newline at end of file
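The services in this repository read these keys with configparser and cast numeric options at the call site (as worker.py and retriever.py do elsewhere in this commit); a minimal sketch of that pattern:

    import configparser

    config = configparser.ConfigParser()
    config.read('config.ini')

    # Values come back as strings; callers cast where needed.
    work_dir = config['default']['work_dir']
    reject_throttle = float(config['feature_database']['reject_throttle'])  # 0.61
    max_input_length = int(config['model']['max_input_length'])             # 1400
    llm_service_address = config['model']['llm_service_address']            # http://127.0.0.1:8001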
llm_service/__init__.py

-from .feature_database import FeatureDataBase, DocumentProcessor, DocumentName  # noqa E401
+from rag.feature_database import FeatureDataBase, DocumentProcessor, DocumentName  # noqa E401
 from .helper import TaskCode, ErrorCode, LogManager  # noqa E401
-from .http_client import OpenAPIClient, ClassifyClient  # noqa E401
+from .http_client import OpenAPIClient, ClassifyModel  # noqa E401
 from .worker import Worker  # noqa E401
\ No newline at end of file
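These package-level re-exports keep the public import surface unchanged even though the feature-database implementation now lives under rag/; a sketch of the assumed consumer-side usage (not shown in this commit):

    # Assumed imports enabled by the re-exports above.
    from llm_service import Worker, FeatureDataBase, DocumentProcessor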
llm_service/feature_database.py  (deleted, 100644 → 0 — contents at parent 6088e14e, removed in full by this commit)

import argparse
import fitz
import re
import os
import time
import pandas as pd
import hashlib
import textract
import shutil
import configparser
from multiprocessing import Pool
from typing import List
from loguru import logger
from BCEmbedding.tools.langchain import BCERerank
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores.faiss import FAISS
from torch.cuda import empty_cache
from bs4 import BeautifulSoup
from elastic_keywords_search import ElasticKeywordsSearch
from retriever import Retriever


def check_envs(args):
    if all(isinstance(item, int) for item in args.DCU_ID):
        os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, args.DCU_ID))
        logger.info(f"Set environment variable CUDA_VISIBLE_DEVICES to {args.DCU_ID}")
    else:
        logger.error(f"The --DCU_ID argument must be a list of integers, but got {args.DCU_ID}")
        raise ValueError("The --DCU_ID argument must be a list of integers")


class DocumentName:

    def __init__(self, directory: str, name: str, category: str):
        self.directory = directory
        self.prefix = name.replace('/', '_')
        self.basename = os.path.basename(name)
        self.origin_path = os.path.join(directory, name)
        self.copy_path = ''
        self._category = category
        self.status = True
        self.message = ''

    def __str__(self):
        return '{},{},{},{}\n'.format(self.basename, self.copy_path, self.status, self.message)


class DocumentProcessor:

    def __init__(self):
        self.image_suffix = ['.jpg', '.jpeg', '.png', '.bmp']
        self.md_suffix = '.md'
        self.text_suffix = ['.txt', '.text']
        self.excel_suffix = ['.xlsx', '.xls', '.csv']
        self.pdf_suffix = '.pdf'
        self.ppt_suffix = '.pptx'
        self.html_suffix = ['.html', '.htm', '.shtml', '.xhtml']
        self.word_suffix = ['.docx', '.doc']
        self.json_suffix = '.json'

    def md5(self, filepath: str):
        hash_object = hashlib.sha256()
        with open(filepath, 'rb') as file:
            chunk_size = 8192
            while chunk := file.read(chunk_size):
                hash_object.update(chunk)
        return hash_object.hexdigest()[0:8]

    def summarize(self, files: list):
        success = 0
        skip = 0
        failed = 0
        for file in files:
            if file.status:
                success += 1
            elif file.message == 'skip':
                skip += 1
            else:
                logger.info('{}文件异常, 异常信息: {} '.format(file.origin_path, file.message))
                failed += 1
        logger.info('解析{}文件,成功{}个,跳过{}个,异常{}个'.format(len(files), success, skip, failed))

    def read_file_type(self, filepath: str):
        filepath = filepath.lower()
        if filepath.endswith(self.pdf_suffix):
            return 'pdf'
        if filepath.endswith(self.md_suffix):
            return 'md'
        if filepath.endswith(self.ppt_suffix):
            return 'ppt'
        if filepath.endswith(self.json_suffix):
            return 'json'
        for suffix in self.image_suffix:
            if filepath.endswith(suffix):
                return 'image'
        for suffix in self.text_suffix:
            if filepath.endswith(suffix):
                return 'text'
        for suffix in self.word_suffix:
            if filepath.endswith(suffix):
                return 'word'
        for suffix in self.excel_suffix:
            if filepath.endswith(suffix):
                return 'excel'
        for suffix in self.html_suffix:
            if filepath.endswith(suffix):
                return 'html'
        return None

    def scan_directory(self, repo_dir: str):
        documents = []
        for directory, _, names in os.walk(repo_dir):
            for name in names:
                category = self.read_file_type(name)
                if category is not None:
                    documents.append(DocumentName(directory=directory, name=name, category=category))
        return documents

    def read(self, filepath: str):
        file_type = self.read_file_type(filepath)
        text = ''
        if not os.path.exists(filepath):
            return text
        try:
            if file_type == 'md' or file_type == 'text':
                text = []
                with open(filepath) as f:
                    txt = f.read()
                    cleaned_txt = re.sub(r'\n\s*\n', '\n\n', txt)
                    text.append(cleaned_txt)
            elif file_type == 'pdf':
                text += self.read_pdf(filepath)
                text = re.sub(r'\n\s*\n', '\n\n', text)
            elif file_type == 'excel':
                text += self.read_excel(filepath)
            elif file_type == 'word' or file_type == 'ppt':
                # https://stackoverflow.com/questions/36001482/read-doc-file-with-python
                # https://textract.readthedocs.io/en/latest/installation.html
                text = textract.process(filepath).decode('utf8')
                text = re.sub(r'\n\s*\n', '\n\n', text)
                if file_type == 'ppt':
                    text = text.replace('\n', ' ')
            elif file_type == 'html':
                with open(filepath) as f:
                    soup = BeautifulSoup(f.read(), 'html.parser')
                    text += soup.text
            elif filepath.endswith('.json'):
                # 打开JSON文件进行读取
                with open(filepath, 'r', encoding='utf-8') as file:
                    # 读取文件的所有行
                    text = file.readlines()
        except Exception as e:
            logger.error((filepath, str(e)))
            return '', e
        return text, None

    def read_excel(self, filepath: str):
        table = None
        if filepath.endswith('.csv'):
            table = pd.read_csv(filepath)
        else:
            table = pd.read_excel(filepath)
        if table is None:
            return ''
        json_text = table.dropna(axis=1).to_json(force_ascii=False)
        return json_text

    def read_pdf(self, filepath: str):
        # load pdf and serialize table
        text = ''
        with fitz.open(filepath) as pages:
            for page in pages:
                text += page.get_text()
                tables = page.find_tables()
                for table in tables:
                    tablename = '_'.join(
                        filter(lambda x: x is not None and 'Col' not in x, table.header.names))
                    pan = table.to_pandas()
                    json_text = pan.dropna(axis=1).to_json(force_ascii=False)
                    text += tablename
                    text += '\n'
                    text += json_text
                    text += '\n'
        return text


def read_and_save(file: DocumentName, file_opr: DocumentProcessor):
    try:
        if os.path.exists(file.copy_path):
            # already exists, return
            logger.info('{} already processed, output file: {}, skip load'.format(
                file.origin_path, file.copy_path))
            return
        logger.info('reading {}, would save to {}'.format(file.origin_path, file.copy_path))
        content, error = file_opr.read(file.origin_path)
        if error is not None:
            logger.error('{} load error: {}'.format(file.origin_path, str(error)))
            return
        if content is None or len(content) < 1:
            logger.warning('{} empty, skip save'.format(file.origin_path))
            return
        cleaned_content = re.sub(r'\n\s*\n', '\n\n', content)
        with open(file.copy_path, 'w') as f:
            f.write(os.path.splitext(file.basename)[0] + '\n')
            f.write(cleaned_content)
    except Exception as e:
        logger.error(f"Error in read_and_save: {e}")


class FeatureDataBase:

    def __init__(self,
                 embeddings: HuggingFaceEmbeddings,
                 reranker: BCERerank,
                 reject_throttle=-1) -> None:
        # logger.debug('loading text2vec model..')
        self.embeddings = embeddings
        self.reranker = reranker
        self.compression_retriever = None
        self.rejecter = None
        self.retriever = None
        self.reject_throttle = reject_throttle if reject_throttle else -1
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=1068, chunk_overlap=32)

    def get_documents(self, text, file):
        # if len(text) <= 1:
        #     return []
        chunks = self.text_splitter.create_documents(text)
        documents = []
        for chunk in chunks:
            # `source` is for return references
            # `read` is for LLM response
            chunk.metadata = {'source': file.basename, 'read': file.origin_path}
            documents.append(chunk)
        return documents

    def build_database(self, files: list, work_dir: str, file_opr: DocumentProcessor, elastic_search=None):
        feature_dir = os.path.join(work_dir, 'db_response')
        if not os.path.exists(feature_dir):
            os.makedirs(feature_dir)

        documents = []
        texts_for_es = []
        metadatas_for_es = []
        ids_for_es = []
        for i, file in enumerate(files):
            if not file.status:
                continue
            # 读取每个file
            text, error = file_opr.read(file.copy_path)
            if error is not None:
                file.status = False
                file.message = str(error)
                continue
            file.message = str(text[0])
            texts_for_es.append(text[0])
            metadatas_for_es.append({'source': file.basename, 'read': file.origin_path})
            ids_for_es.append(str(i))
            document = self.get_documents(text, file)
            documents += document
            logger.debug('Positive pipeline {}/{}.. register 《{}》 and split {} documents'.format(
                i + 1, len(files), file.basename, len(document)))

        if elastic_search is not None:
            logger.debug('ES database pipeline register {} documents into database...'.format(len(texts_for_es)))
            es_time_before_register = time.time()
            elastic_search.add_texts(texts_for_es, metadatas=metadatas_for_es, ids=ids_for_es)
            es_time_after_register = time.time()
            logger.debug('ES database pipeline take time: {} '.format(
                es_time_after_register - es_time_before_register))

        logger.debug('Vector database pipeline register {} documents into database...'.format(len(documents)))
        ve_time_before_register = time.time()
        vs = FAISS.from_documents(documents, self.embeddings)
        vs.save_local(feature_dir)
        ve_time_after_register = time.time()
        logger.debug('Vector database pipeline take time: {} '.format(
            ve_time_after_register - ve_time_before_register))

    def preprocess(self, files: list, work_dir: str, file_opr: DocumentProcessor):
        preproc_dir = os.path.join(work_dir, 'preprocess')
        if not os.path.exists(preproc_dir):
            os.makedirs(preproc_dir)

        pool = Pool(processes=16)
        for idx, file in enumerate(files):
            if not os.path.exists(file.origin_path):
                file.status = False
                file.message = 'skip not exist'
                continue

            if file._category == 'image':
                file.status = False
                file.message = 'skip image'
            elif file._category in ['pdf', 'word', 'ppt', 'html', 'excel']:
                # read pdf/word/excel file and save to text format
                md5 = file_opr.md5(file.origin_path)
                file.copy_path = os.path.join(preproc_dir, '{}.text'.format(md5))
                pool.apply_async(read_and_save, args=(file, file_opr))
            elif file._category in ['md', 'text']:
                # rename text files to new dir
                file.copy_path = os.path.join(preproc_dir, file.origin_path.replace('/', '_')[-84:])
                try:
                    shutil.copy(file.origin_path, file.copy_path)
                    file.status = True
                    file.message = 'preprocessed'
                except Exception as e:
                    file.status = False
                    file.message = str(e)
            elif file._category in ['json']:
                file.status = True
                file.copy_path = file.origin_path
                file.message = 'preprocessed'
            else:
                file.status = False
                file.message = 'skip unknown format'

        pool.close()
        logger.debug('waiting for preprocess read finish..')
        pool.join()

        # check process result
        for file in files:
            if file._category in ['pdf', 'word', 'excel']:
                if os.path.exists(file.copy_path):
                    file.status = True
                    file.message = 'preprocessed'
                else:
                    file.status = False
                    file.message = 'read error'

    def initialize(self, files: list, work_dir: str, file_opr: DocumentProcessor, elastic_search=None):
        self.preprocess(files=files, work_dir=work_dir, file_opr=file_opr)
        self.build_database(files=files, work_dir=work_dir, file_opr=file_opr, elastic_search=elastic_search)

    def merge_db_response(self, faiss: FAISS, files: list, work_dir: str, file_opr: DocumentProcessor):
        feature_dir = os.path.join(work_dir, 'db_response')
        if not os.path.exists(feature_dir):
            os.makedirs(feature_dir)

        documents = []
        for i, file in enumerate(files):
            logger.debug('{}/{}.. register 《{}》 into database...'.format(i + 1, len(files), file.basename))
            if not file.status:
                continue
            # 读取每个file
            text, error = file_opr.read(file.copy_path)
            if error is not None:
                file.status = False
                file.message = str(error)
                continue
            logger.info(str(len(text)), text, str(text[0]))
            file.message = str(text[0])
            # file.message = str(len(text))
            # logger.info('{} content length {}'.format(
            #     file._category, len(text)))
            documents += self.get_documents(text, file)

        if documents:
            vs = FAISS.from_documents(documents, self.embeddings)
            if faiss:
                faiss.merge_from(vs)
                faiss.save_local(feature_dir)
            else:
                vs.save_local(feature_dir)


def test_reject(retriever: Retriever):
    """Simple test reject pipeline."""
    real_questions = [
        '姚明是谁?',
        'CBBA是啥?',
        '差多少嘞?',
        'cnn 的全称是什么?',
        'transformer啥意思?',
        '成都有什么好吃的推荐?',
        '树博士是什么?',
        '白马非马啥意思?',
        'mmpose 如何安装?',
        '今天天气如何?',
        '写一首五言律诗?',
        '先有鸡还是先有蛋?',
        '如何在Gromacs中进行蛋白质的动态模拟?',
        'wy-vSphere 7 海光平台兼容补丁?',
        '在Linux系统中,如何进行源码包的安装?',
    ]
    for example in real_questions:
        relative, _ = retriever.is_relative(example)
        if relative:
            logger.warning(f'process query: {example}')
            retriever.query(example)
            empty_cache()
        else:
            logger.error(f'reject query: {example}')
            empty_cache()


def parse_args():
    """Parse command-line arguments."""
    parser = argparse.ArgumentParser(description='Feature store for processing directories.')
    parser.add_argument('--work_dir', type=str, default='', help='自定义.')
    parser.add_argument('--repo_dir', type=str, default='', help='需要读取的文件目录.')
    parser.add_argument('--config_path', default='./ai/rag/config.ini', help='config目录')
    parser.add_argument('--DCU_ID', default=[7], help='设置DCU')
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()
    log_file_path = os.path.join(args.work_dir, 'application.log')
    logger.add(log_file_path, rotation='10MB', compression='zip')
    check_envs(args)
    config = configparser.ConfigParser()
    config.read(args.config_path)

    # only init vector retriever
    retriever = Retriever(config)
    fs_init = FeatureDataBase(embeddings=retriever.embeddings, reranker=retriever.reranker)

    # init es retriever, drop_old means build new one or updata the 'index_name'
    es_url = config.get('rag', 'es_url')
    index_name = config.get('rag', 'index_name')
    elastic_search = ElasticKeywordsSearch(elasticsearch_url=es_url, index_name=index_name, drop_old=True)

    # walk all files in repo dir
    file_opr = DocumentProcessor()
    files = file_opr.scan_directory(repo_dir=args.repo_dir)
    fs_init.initialize(files=files, work_dir=args.work_dir, file_opr=file_opr, elastic_search=elastic_search)
    file_opr.summarize(files)

    # del fs_init
    # with open(os.path.join(args.work_dir, 'sample', 'positive.json')) as f:
    #     positive_sample = json.load(f)
    # with open(os.path.join(args.work_dir, 'sample', 'negative.json')) as f:
    #     negative_sample = json.load(f)
    #
    # with open(os.path.join(args.work_dir, 'sample', 'positive.txt'), 'r', encoding='utf-8') as file:
    #     positive_sample = []
    #     for line in file:
    #         positive_sample.append(line.strip())
    #
    # with open(os.path.join(args.work_dir, 'sample', 'negative.txt'), 'r', encoding='utf-8') as file:
    #     negative_sample = []
    #     for line in file:
    #         negative_sample.append(line.strip())
    #
    # test_reject(retriever)
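The pipeline this module implemented (and which survives in rag/feature_database.py) wires the pieces together as in the __main__ block above; a condensed sketch, with placeholder paths and with the embeddings/reranker assumed to be supplied by a Retriever:

    # Condensed sketch of the removed pipeline; paths are placeholders.
    config = configparser.ConfigParser()
    config.read('./ai/rag/config.ini')
    retriever = Retriever(config)                        # assumed to expose .embeddings and .reranker
    file_opr = DocumentProcessor()
    files = file_opr.scan_directory(repo_dir='/path/to/docs')
    fs = FeatureDataBase(embeddings=retriever.embeddings, reranker=retriever.reranker)
    fs.initialize(files=files, work_dir='/path/to/work_dir', file_opr=file_opr)  # elastic_search optional
    file_opr.summarize(files)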
llm_service/http_client.py

@@ -76,14 +76,13 @@ class OpenAPIClient:
 class ClassifyModel:
-    def __init__(self, model_path, dcu_id):
+    def __init__(self, model_path,):
         logger.info("Starting initial bert class model")
         self.cls_model = BertForSequenceClassification.from_pretrained(model_path).float().cuda()
         self.cls_model.load_state_dict(torch.load(os.path.join(model_path, 'bert_cls_model.pth')))
         self.cls_model.eval()
         self.cls_tokenizer = BertTokenizer.from_pretrained(model_path)
-        os.environ["CUDA_VISIBLE_DEVICES"] = dcu_id
-        logger.info(f"Set environment variable CUDA_VISIBLE_DEVICES to {dcu_id}")

     def classfication(self, sentence):
         inputs = self.cls_tokenizer(
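With dcu_id dropped from the constructor, device selection is presumably handled by the launching process rather than inside ClassifyModel; a sketch of the new construction under that assumption (path taken from cls_model_path in config.ini):

    # Assumes CUDA_VISIBLE_DEVICES is already set before the process starts.
    cls = ClassifyModel('/home/llm_model/bert-base-chinese')
    result = cls.classfication('服务器如何续保?')  # return shape not shown in this diff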
llm_service/retriever.py

@@ -3,8 +3,8 @@ import argparse
 import time
 import configparser
 import numpy as np
 from aiohttp import web
-from multiprocessing import Value
 from torch.cuda import empty_cache
 from BCEmbedding.tools.langchain import BCERerank
 from langchain_community.embeddings import HuggingFaceEmbeddings

@@ -339,9 +339,6 @@ def main():
     retriever = cache.get(reject_throttle=float(config['feature_database']['reject_throttle']),
                           work_dir=config['default']['work_dir'])
     test_query(retriever, args.query)
-    # server_ready = Value('i', 0)
-    # rag_retrieve(config_path=args.config_path,
-    #              server_ready=server_ready)

 if __name__ == '__main__':
llm_service/worker.py

 import time
 import json
 import os
 import pickle
 from loguru import logger
+from .utils import COMMON
 from .helper import ErrorCode
 from .http_client import OpenAPIClient, ClassifyModel, CacheRetriever

@@ -10,47 +12,6 @@ SECURITY_TEMAPLTE = '判断以下句子是否涉及政治、辱骂、色情、
 GENERATE_TEMPLATE = '<Data>{}</Data>\n回答要求:\n如果你不清楚答案,你需要澄清。\n避免提及你是从 <Data></Data> 获取的知识。\n保持答案与 <Data></Data> 中描述的一致。\n使用 Markdown 语法优化回答格式。\n使用与问题相同的语言回答。问题:"{}"'
 MARKDOWN_TEMPLATE = '问题:“{}”\n请使用markdown格式回答此问题'
-COMMON = {
-    "<光合组织登记网址>": "https://www.hieco.com.cn/partner?from=timeline",
-    "<官网>": "https://www.sugon.com/after_sale/policy?sh=1",
-    "<平台联系方式>": "1、访问官网,根据您所在地地址联系平台人员,网址地址:https://www.sugon.com/about/contact;\n2、点击人工客服进行咨询;\n3、请您拨打中科曙光服务热线400-810-0466联系人工进行咨询。",
-    "<购买与维修的咨询方法>": "1、确定付费处理,可以微信搜索'sugon中科曙光服务'小程序,选择'在线报修'业务\n2、先了解价格,可以微信搜索'sugon中科曙光服务'小程序,选择'其他咨询'业务\n3、请您拨打中科曙光服务热线400-810-0466",
-    "<服务器续保流程>": "1、微信搜索'sugon中科曙光服务'小程序,选择'延保与登记'业务\n2、点击人工客服进行登记\n3、请您拨打中科曙光服务热线400-810-0466根据语音提示选择维保与购买",
-    "<XC内外网OS网盘链接>": "【腾讯文档】XC内外网OS网盘链接:https://docs.qq.com/sheet/DTWtXbU1BZHJvWkJm",
-    "<W360-G30机器,安装Win7使用的镜像链接>": "W360-G30机器,安装Win7使用的镜像链接:https://pan.baidu.com/s/1SjHqCP6kJ9KzdJEBZDEynw;提取码:x6m4",
-    "<麒麟系统搜狗输入法下载链接>": "软件下载链接(百度云盘):链接:https://pan.baidu.com/s/18Iluvs4BOAfFET0yFMBeLQ,提取码:bhkf",
-    "<X660 G45 GPU服务器拆解视频网盘链接>": "链接: https://pan.baidu.com/s/1RkRGh4XY1T2oYftGnjLp4w;提取码: v2qi",
-    "<DS800,SANTRICITY存储IBM版本模拟器网盘链接>": "链接:https://pan.baidu.com/s/1euG9HGbPfrVbThEB8BX76g;提取码:o2ya",
-    "<E80-D312(X680-G55)风冷整机组装说明下载链接>": "链接:https://pan.baidu.com/s/17KDpm-Z9lp01WGp9sQaQ4w;提取码:0802",
-    "<X680 G55 风冷相关资料下载链接>": "链接:https://pan.baidu.com/s/1KQ-hxUIbTWNkc0xzrEQLjg;提取码:0802",
-    "<R620 G51刷写EEPROM下载>": "下载链接如下:http://10.2.68.104/tools/bytedance/eeprom/",
-    "<X7450A0服务器售后培训文件网盘链接>": "网盘下载:https://pan.baidu.com/s/1tZJIf_IeQLOWsvuOawhslQ?pwd=kgf1;提取码:kgf1",
-    "<福昕阅读器补丁链接>": "补丁链接: https://pan.baidu.com/s/1QJQ1kHRplhhFly-vxJquFQ,提取码: aupx1",
-    "<W330-H35A_22DB4/W3335HA安装win7网盘链接>": "硬盘链接: https://pan.baidu.com/s/1fDdGPH15mXiw0J-fMmLt6Q提取码: k97i",
-    "<X680 G55服务器售后培训资料网盘链接>": "云盘连接下载:链接:https://pan.baidu.com/s/1gaok13DvNddtkmk6Q-qLYg?pwd=xyhb提取码:xyhb",
-    "<展厅管理员>": "北京-穆淑娟18001053012\n天津-马书跃15720934870\n昆山-关天琪15304169908\n成都-贾小芳18613216313\n重庆-李子艺17347743273\n安阳-郭永军15824623085\n桐乡-李梦瑶18086537055\n青岛-陶祉伊15318733259",
-    "<线上预约展厅>": "北京、天津、昆山、成都、重庆、安阳、桐乡、青岛",
-    "<马华>": "联系人:马华,电话:13761751980,邮箱:china@pinbang.com",
-    "<梁静>": "联系人:梁静,电话:18917566297,邮箱:ing.liang@omaten.com",
-    "<徐斌>": "联系人:徐斌,电话:13671166044,邮箱:244898943@qq.com",
-    "<俞晓枫>": "联系人:俞晓枫,电话13750869272,邮箱:857233013@qq.com",
-    "<刘广鹏>": "联系人:刘广鹏,电话13321992411,邮箱:liuguangpeng@pinbang.com",
-    "<马英伟>": "联系人:马英伟,电话:13260021849,邮箱:13260021849@163.com",
-    "<杨洋>": "联系人:杨洋,电话15801203938,邮箱bing523888@163.com",
-    "<展会合规要求>": "1.展品内容:展品内容需符合公司合规要求,展示内容需经过法务合规审查。\n2.文字材料内容:文字材料内容需符合公司合规要求,展示内容需经过法务合规审查。\n3.展品标签:展品标签内容需符合公司合规要求。\n4.礼品内容:礼品内容需符合公司合规要求。\n5.视频内容:视频内容需符合公司合规要求,展示内容需经过法务合规审查。\n6.讲解词内容:讲解词内容需符合公司合规要求,展示内容需经过法务合规审查。\n7.现场发放材料:现场发放的材料内容需符合公司合规要求。\n8.展示内容:整体展示内容需要经过法务合规审查。",
-    "<展会质量>": "1.了解展会的组织者背景、往届展会的评价以及提供的服务支持,确保展会的专业性和高效性。\n.了解展会的规模、参观人数、行业影响力等因素,以判断展会是否能够提供足够的曝光度和商机。\n3.关注同行业其他竞争对手是否参展,以及他们的展位布置、展示内容等信息,以便制定自己的参展策略。\n4.展会的日期是否与公司的其他重要活动冲突,以及举办地点是否便于客户和合作伙伴的参观。\n5.销售部门会询问展会方提供的宣传渠道和推广服务,以及如何利用这些资源来提升公司及产品的知名度。\n6.记录展会期间的重要领导参观、商机线索、合作洽谈、公司拜访预约等信息,跟进后续商业机会。",
-    "<摊位费规则>": "根据展位面积大小,支付相应费用。\n展位照明费:支付展位内的照明服务费。\n展位保安费:支付展位内的保安服务费。\n展位网络使用费:支付展位内网络使用的费用。\n展位电源使用费:支付展位内电源使用的费用。",
-    "<展会主题要求>": "展会主题的确定需要符合公司产品和服务业务范围,以确保能够吸引目标客户群体。因此,确定展会主题时,需要考虑以下因素:\n专业性:展会的主题应确保专业性,符合行业特点和目标客户的需求。\n目标客户群体:展会的主题定位应考虑目标客户群体,确保能够吸引他们的兴趣。\n业务重点:展会的主题应突出公司的业务重点和优势,以便更好地推广公司的核心产品或服务。\n行业影响力:展会的主题定位需要考虑行业的最新发展趋势,以凸显公司的行业地位和影响力。\n往届展会经验:可以参考往届展会的主题定位,总结经验教训,以确定本届展会的主题。\n市场部意见:在确定展会主题时,应听取市场部的意见,确保主题符合公司的整体市场战略。\n领导意见:还需要考虑公司领导的意见,以确保展会主题符合公司的战略发展方向。",
-    "<办理展商证注意事项>": "人员范围:除公司领导和同事需要办理展商证外,展会运营工作人员也需要办理。\n提前准备:展商证的办理需要提前进行,以确保摄影师、摄像师等工作人员可以提前入场进行布置。\n办理流程:需要熟悉展商证的办理流程,准备好相关材料,如身份证件等。\n数量需求:需要评估所需的展商证数量,避免数量不足或过多的情况。\n有效期限:展商证的有效期限需要注意,避免在展期内过期。\n存放安全:办理完的展商证需要妥善保管,避免丢失或被他人使用。\n使用规范:使用展商证时需要遵守展会相关规定,不得转让给他人使用。\n回收处理:展会结束后,需要及时回收展商证,避免泄露相关信息。",
-    "<项目单价要求>": "请注意:无论是否年框供应商,项目单价都不得超过采购部制定的“2024常见活动项目标准单价”,此报价仅可内部使用,严禁外传",
-    "<年框供应商细节表格>": "在线表格https://kdocs.cn/l/camwZE63frNw",
-    "<年框供应商流程>": "1.需求方发出项目需求(大型项目需比稿)\n2.外协根据项目需求报价,提供需求方“预算单”(按照基准单价报价,如有发现不按单价情况,解除合同不再使用)\n3.需求方确认预算价格,并提交OA市场活动申请\n4.外协现场执行\n5.需求方现场验收,并签署验收单(物料、设备、人员等实际清单)\n6.外协出具结算单(金额与验收单一致,加盖公章)、结案报告、年框合同,作为报销凭证\n7.外协请需求方项目负责人填写“满意度调研表”(如无,会影响年度评价)\n8.需求方项目经理提交报销",
-    "<市场活动结案报告内容>": "1.项目简介(时间、地点、参与人数等);2.最终会议安排;3.活动各环节现场图片;4.费用相关证明材料(如执行人员、物料照片);5.活动成效汇总;6.活动原始照片/视频网络链接",
-    "<展板设计选择>": "1.去OA文档中心查找一些设计模板; 2. 联系专业的活动服务公司来协助设计",
-    "<餐费标准>": "一般地区的餐饮费用规定为不超过300元/人(一顿正餐),特殊地区则为不超过400元/人(一顿正餐),特殊地区的具体规定请参照公司的《差旅费管理制度》",
-    "": "",
-}

 def substitution(chunks):
     # 翻译特殊字符

@@ -78,7 +39,8 @@ class Worker:
         cls_model_path = config['model']['cls_model_path']
         local_server_address = config['model']['local_service_address']
         reject_throttle = float(config['feature_database']['reject_throttle'])
+        self.embedding_model_path = config['feature_database']['embedding_model_path']
+        self.reranker_model_path = config['feature_database']['reranker_model_path']
         if not llm_service_address:
             raise Exception('llm_service_address is required in config.ini')
         if not cls_model_path:

@@ -152,7 +114,9 @@ class Worker:
         '''微调模型回答'''
         logger.info('Prompt is: {}, History is: {}'.format(query, history))
         response_direct = self.openapi_local_server.chat(query, history)
-        return response_direct
+        data = json.loads(response_direct.content.decode("utf-8"))
+        output = data["text"]
+        return output

     async def produce_response(self, config, query, history, stream=False):
         response = ''

@@ -173,7 +137,10 @@ class Worker:
             if len(chunks) == 0:
                 logger.debug('Response by finetune model')
-                chunks = [self.response_by_finetune(query, history=history)]
+                response = self.response_by_finetune(query, history=history)
+                data = json.loads(response.content.decode("utf-8"))
+                chunks = [output]
             elif use_template:
                 logger.debug('Response by template')
                 response = self.format_rag_result(chunks, references, stream=stream)
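COMMON maps placeholder tags such as <官网> to canned contact and download information, and substitution(chunks) ("翻译特殊字符") is where they get expanded; the body of that function is not visible in this diff, but a plausible sketch of the replacement step is:

    def substitution(chunks):
        # Sketch only; the real implementation is not shown in this commit.
        result = []
        for chunk in chunks:
            for tag, expansion in COMMON.items():
                if tag:  # skip the empty '' entry
                    chunk = chunk.replace(tag, expansion)
            result.append(chunk)
        return result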
rag/feature_database.py

@@ -8,7 +8,7 @@ import hashlib
 import textract
 import shutil
 import configparser
+import json
 from multiprocessing import Pool
 from typing import List
 from loguru import logger

@@ -18,17 +18,8 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.vectorstores.faiss import FAISS
 from torch.cuda import empty_cache
 from bs4 import BeautifulSoup
-from elastic_keywords_search import ElasticKeywordsSearch
+from .elastic_keywords_search import ElasticKeywordsSearch
-from retriever import Retriever
+from .retriever import Retriever


-def check_envs(args):
-    if all(isinstance(item, int) for item in args.DCU_ID):
-        os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, args.DCU_ID))
-        logger.info(f"Set environment variable CUDA_VISIBLE_DEVICES to {args.DCU_ID}")
-    else:
-        logger.error(f"The --DCU_ID argument must be a list of integers, but got {args.DCU_ID}")
-        raise ValueError("The --DCU_ID argument must be a list of integers")


 class DocumentName:

@@ -482,32 +473,14 @@ if __name__ == '__main__':
     log_file_path = os.path.join(args.work_dir, 'application.log')
     logger.add(log_file_path, rotation='10MB', compression='zip')
-    check_envs(args)
     config = configparser.ConfigParser()
     config.read(args.config_path)

     # only init vector retriever
-    embedding_model_path = config.get('rag', 'embedding_model_path') or None
-    reranker_model_path = config.get('rag', 'reranker_model_path') or None
-    if embedding_model_path and reranker_model_path:
-        embeddings = HuggingFaceEmbeddings(model_name=embedding_model_path,
-                                           model_kwargs={'device': 'cuda'},
-                                           encode_kwargs={'batch_size': 1, 'normalize_embeddings': True})
-        embeddings.client = embeddings.client.half()
-        reranker_args = {'model': reranker_model_path,
-                         'top_n': int(config['rag']['vector_top_k']),
-                         'device': 'cuda', 'use_fp16': True}
-        reranker = BCERerank(**reranker_args)
-        fs_init = FeatureDataBase(embeddings=embeddings, reranker=reranker)
+    retriever = Retriever(config)
+    fs_init = FeatureDataBase(embeddings=retriever.embeddings,
+                              reranker=retriever.reranker)

     # init es retriever, drop_old means build new one or updata the 'index_name'
     es_url = config.get('rag', 'es_url')

@@ -523,22 +496,3 @@ if __name__ == '__main__':
     files = file_opr.scan_directory(repo_dir=args.repo_dir)
     fs_init.initialize(files=files, work_dir=args.work_dir, file_opr=file_opr, elastic_search=elastic_search)
     file_opr.summarize(files)
-    # del fs_init
-    # with open(os.path.join(args.work_dir, 'sample', 'positive.json')) as f:
-    #     positive_sample = json.load(f)
-    # with open(os.path.join(args.work_dir, 'sample', 'negative.json')) as f:
-    #     negative_sample = json.load(f)
-    #
-    # with open(os.path.join(args.work_dir, 'sample', 'positive.txt'), 'r', encoding='utf-8') as file:
-    #     positive_sample = []
-    #     for line in file:
-    #         positive_sample.append(line.strip())
-    #
-    # with open(os.path.join(args.work_dir, 'sample', 'negative.txt'), 'r', encoding='utf-8') as file:
-    #     negative_sample = []
-    #     for line in file:
-    #         negative_sample.append(line.strip())
-    #
-    # test_reject(retriever)
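The __main__ block still reads Elasticsearch settings from a [rag] section (config.get('rag', 'es_url') and config.get('rag', 'index_name')), which the config.ini shown in this commit does not define; a minimal sketch of such a section, with placeholder values:

    ; Sketch of the [rag] section this script expects; values are placeholders.
    [rag]
    es_url = http://127.0.0.1:9200
    index_name = chat_demo_docs
    vector_top_k = 5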
server_start.py

@@ -34,7 +34,7 @@ def workflow(args):
     query = input_json['query']
     history = input_json.get('history', [])
     try:
-        code, reply, references = await assistant.produce_response(config,
+        _, reply, references = await assistant.produce_response(config,
                                                                    query=query,
                                                                    history=history)
     except Exception as e:

@@ -51,7 +51,7 @@ def workflow(args):
     async def event_generator():
         try:
-            code, reply, references = await assistant.produce_response(config,
+            _, reply, references = await assistant.produce_response(config,
                                                                        query=query,
                                                                        history=history,
                                                                        stream=True)
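Both handlers pull query and an optional history out of the request JSON; a sketch of the payload the stream endpoint therefore expects (the endpoint name comes from base_url % 'stream' in client.py, and the values are placeholders):

    # Hypothetical request body; only 'query' and 'history' are confirmed by the handlers above.
    payload = {
        'query': '服务器如何续保?',
        'history': [],           # optional; defaults to [] on the server
        'user_id': 'demo-user',  # assumption: forwarded by the new stream_query(query, user_id=None)
    }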