Commit de6f9f97 authored by Rayyyyy

jifu v1.0

parent 74767f88
@@ -11,4 +11,4 @@ reranker_model_path = /path/to/your/bce-reranker-base_v1
[llm]
local_llm_path = /path/to/your/internlm-chat-7b
-accelerate = False
+use_vllm = False
\ No newline at end of file
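For illustration, a minimal sketch (not part of this commit; the config path is a placeholder) of reading the renamed switch back with configparser, which is how llm_inference consumes it further down:

import configparser

config = configparser.ConfigParser()
config.read('config.ini')  # placeholder; the services pass args.config_path here
# getboolean() accepts True/False, yes/no, on/off, 1/0 (case-insensitive)
use_vllm = config.getboolean('llm', 'use_vllm')
local_llm_path = config['llm']['local_llm_path']
print(use_vllm, local_llm_path)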
-from vllm import LLM, SamplingParams
-import os
-os.environ["CUDA_VISIBLE_DEVICES"] = '7'
-prompt = ''
-model_path = ''
-sampling_params = SamplingParams(temperature=1, top_p=0.95)
-llm = LLM(model=model_path,
-          trust_remote_code=True,
-          enforce_eager=True,
-          tensor_parallel_size=1)
-outputs = llm.generate(prompt, sampling_params)
-for output in outputs:
-    prompt = output.prompt
-    generated_text = output.outputs[0].text
-    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

import os
import time
import torch
import argparse
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
from vllm import LLM, SamplingParams


def infer_hf_chatglm(model_path, prompt):
    '''transformers 推理 chatglm2'''
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True, device_map="auto").half().cuda()
    model = model.eval()
    start_time = time.time()
    generated_text, _ = model.chat(tokenizer, prompt, history=[])
    print("chat time ", time.time() - start_time)
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    return generated_text


def infer_hf_llama3(model_path, prompt):
    '''transformers 推理 llama3'''
    input_query = {"role": "user", "content": prompt}
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_path, torch_dtype="auto", device_map="auto")
    input_ids = tokenizer.apply_chat_template(
        [input_query], add_generation_prompt=True, return_tensors="pt").to(model.device)
    outputs = model.generate(
        input_ids,
        max_new_tokens=512,
        do_sample=True,
        temperature=1,
        top_p=0.95,
    )
    response = outputs[0][input_ids.shape[-1]:]
    generated_text = tokenizer.decode(response, skip_special_tokens=True)
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    return generated_text


def infer_vllm_llama3(model_path, message, tp_size=1, max_model_len=1024):
    '''vllm 推理 llama3'''
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    messages = [{"role": "user", "content": message}]
    print(f"Prompt: {messages!r}")
    sampling_params = SamplingParams(temperature=1,
                                     top_p=0.95,
                                     max_tokens=1024,
                                     stop_token_ids=[tokenizer.eos_token_id])
    llm = LLM(model=model_path,
              max_model_len=max_model_len,
              trust_remote_code=True,
              enforce_eager=True,
              dtype="float16",
              tensor_parallel_size=tp_size)
    # generate answer
    start_time = time.time()
    prompt_token_ids = [tokenizer.apply_chat_template(messages, add_generation_prompt=True)]
    outputs = llm.generate(prompt_token_ids=prompt_token_ids, sampling_params=sampling_params)
    print("total infer time", time.time() - start_time)
    # results
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Generated text: {generated_text!r}")


def infer_vllm_chatglm(model_path, message, tp_size=1):
    '''vllm 推理 chatglm2'''
    sampling_params = SamplingParams(temperature=1.0,
                                     top_p=0.9,
                                     max_tokens=1024)
    llm = LLM(model=model_path,
              trust_remote_code=True,
              enforce_eager=True,
              dtype="float16",
              tensor_parallel_size=tp_size)
    # generate answer
    print(f"chatglm2 Prompt: {message!r}")
    outputs = llm.generate(message, sampling_params=sampling_params)
    # results
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Generated text: {generated_text!r}")


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', default='')
    parser.add_argument('--query', default="DCU是什么?", help='提问的问题.')
    parser.add_argument('--use_hf', action='store_true')
    args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = parse_args()
    is_llama = "llama" in args.model_path
    print("Is llama", is_llama)
    if args.use_hf:
        # transformers
        if is_llama:
            infer_hf_llama3(args.model_path, args.query)
        else:
            infer_hf_chatglm(args.model_path, args.query)
    else:
        # vllm
        if is_llama:
            infer_vllm_llama3(args.model_path, args.query)
        else:
            infer_vllm_chatglm(args.model_path, args.query)
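As a reading aid, a minimal sketch (not part of this commit; the model path is a placeholder) of what apply_chat_template renders before generation, which is roughly the prompt that infer_hf_llama3 and infer_vllm_llama3 hand to the model:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('/path/to/your/llama3')  # placeholder
messages = [{"role": "user", "content": "DCU是什么?"}]
# tokenize=False returns the rendered prompt text instead of token ids
prompt_text = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True)
print(prompt_text)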
@@ -21,15 +21,6 @@ from bs4 import BeautifulSoup
from .retriever import CacheRetriever, Retriever

-def check_envs(args):
-    if all(isinstance(item, int) for item in args.DCU_ID):
-        os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, args.DCU_ID))
-        logger.info(f"Set environment variable CUDA_VISIBLE_DEVICES to {args.DCU_ID}")
-    else:
-        logger.error(f"The --DCU_ID argument must be a list of integers, but got {args.DCU_ID}")
-        raise ValueError("The --DCU_ID argument must be a list of integers")

class DocumentName:
    def __init__(self, directory: str, name: str, category: str):
@@ -495,7 +486,6 @@ if __name__ == '__main__':
    log_file_path = os.path.join(args.work_dir, 'application.log')
    logger.add(log_file_path, rotation='10MB', compression='zip')
-    check_envs(args)
    config = configparser.ConfigParser()
    config.read(args.config_path)
...
@@ -30,6 +30,7 @@ class ErrorCode(Enum):
    SEARCH_FAIL = 15, 'Search fail, please check TOKEN and quota'
    NOT_FIND_RELATED_DOCS = 16, 'No relevant documents found, the following answer is generated directly by LLM.'
    NON_COMPLIANCE_QUESTION = 17, 'Non-compliance question, refusing to answer.'
    SCORE_ERROR = 18, 'Get score error.'

    def __new__(cls, value, description):
        """Create new instance of ErrorCode."""
...
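For context, a standalone sketch (hypothetical class name, not part of this commit) of the value-plus-description Enum pattern that the new SCORE_ERROR member plugs into:

from enum import Enum

class DemoCode(Enum):
    SCORE_ERROR = 18, 'Get score error.'

    def __new__(cls, value, description):
        obj = object.__new__(cls)
        obj._value_ = value           # the enum value is the integer
        obj.description = description
        return obj

print(DemoCode.SCORE_ERROR.value, DemoCode.SCORE_ERROR.description)  # 18 Get score error.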
@@ -10,15 +10,6 @@ from loguru import logger
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel

-def check_envs(args):
-    if all(isinstance(item, int) for item in args.DCU_ID):
-        os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, args.DCU_ID))
-        logger.info(f"Set environment variable CUDA_VISIBLE_DEVICES to {args.gita}")
-    else:
-        logger.error(f"The --DCU_ID argument must be a list of integers, but got {args.DCU_ID}")
-        raise ValueError("The --DCU_ID argument must be a list of integers")

def build_history_messages(prompt, history, system: str = None):
    history_messages = []
    if system is not None and len(system) > 0:
@@ -32,8 +23,8 @@ def build_history_messages(prompt, history, system: str = None):

class InferenceWrapper:

-    def __init__(self, model_path: str, accelerate: bool, stream_chat: bool):
-        self.accelerate = accelerate
+    def __init__(self, model_path: str, use_vllm: bool, stream_chat: bool):
+        self.use_vllm = use_vllm
        self.stream_chat = stream_chat
        # huggingface
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
@@ -41,11 +32,12 @@ class InferenceWrapper:
            trust_remote_code=True,
            device_map='auto',
            torch_dtype=torch.bfloat16).eval()
-        if self.accelerate:
+        if self.use_vllm:
            try:
                # fastllm
                from fastllm_pytools import llm
                if self.stream_chat:
                    # fastllm的流式初始化
                    self.model = llm.model(model_path)
                else:
                    self.model = llm.from_hf(self.model, self.tokenizer, dtype="float16").cuda()
@@ -55,9 +47,10 @@ class InferenceWrapper:

    def chat(self, prompt: str, history=[]):
        '''问答'''
        output_text = ''
        try:
-            if self.accelerate:
+            if self.use_vllm:
                output_text = self.model.response(prompt)
            else:
                output_text, _ = self.model.chat(self.tokenizer,
@@ -71,7 +64,7 @@ class InferenceWrapper:

    def chat_stream(self, prompt: str, history=[]):
        '''流式服务'''
-        if self.accelerate:
+        if self.use_vllm:
            from fastllm_pytools import llm
            # Fastllm
            for response in self.model.stream_response(prompt, history=[]):
@@ -94,13 +87,13 @@ class LLMInference:
                 model_path: str,
                 tensor_parallel_size: int,
                 device: str = 'cuda',
-                 accelerate: bool = False
+                 use_vllm: bool = False
                 ) -> None:
        self.device = device
        self.inference = InferenceWrapper(model_path,
-                                          accelerate=accelerate,
+                                          use_vllm=use_vllm,
                                          tensor_parallel_size=tensor_parallel_size)

    def generate_response(self, prompt, history=[]):
@@ -123,38 +116,16 @@ class LLMInference:
        return output_text, error
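For orientation, a minimal usage sketch (not part of this commit; the model path is a placeholder, and chat() is assumed to return the generated text, as in the removed infer_test helper shown next):

wrapper = InferenceWrapper('/path/to/your/internlm-chat-7b',  # placeholder path
                           use_vllm=False,                    # plain transformers path
                           stream_chat=False)
answer = wrapper.chat('你好,请介绍北京大学', history=[])
print(answer)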
-def infer_test(args):
-    config = configparser.ConfigParser()
-    config.read(args.config_path)
-    model_path = config['llm']['local_llm_path']
-    accelerate = config.getboolean('llm', 'accelerate')
-    inference_wrapper = InferenceWrapper(model_path,
-                                         accelerate=accelerate,
-                                         tensor_parallel_size=1)
-    # prompt = "hello,please introduce yourself..."
-    prompt = "你好,请介绍北京大学"
-    history = []
-    time_first = time.time()
-    output_text = inference_wrapper.chat(prompt, use_history=True, history=history)
-    time_second = time.time()
-    logger.debug('问题:{} 回答:{} \ntimecost {} '.format(
-        prompt, output_text, time_second - time_first))

def llm_inference(args):
-    """
-    启动 Web 服务器,接收 HTTP 请求,并通过调用本地的 LLM 推理服务生成响应.
-    """
+    '''启动 Web 服务器,接收 HTTP 请求,并通过调用本地的 LLM 推理服务生成响应. '''
    config = configparser.ConfigParser()
    config.read(args.config_path)
    bind_port = int(config['default']['bind_port'])
    model_path = config['llm']['local_llm_path']
-    accelerate = config.getboolean('llm', 'accelerate')
+    use_vllm = config.getboolean('llm', 'use_vllm')
    inference_wrapper = InferenceWrapper(model_path,
-                                         accelerate=accelerate,
+                                         use_vllm=use_vllm,
                                         stream_chat=args.stream_chat)

    async def inference(request):
        start = time.time()
@@ -175,6 +146,15 @@ def llm_inference(args):
    web.run_app(app, host='0.0.0.0', port=bind_port)

def set_envs(dcu_ids):
    try:
        os.environ["CUDA_VISIBLE_DEVICES"] = dcu_ids
        logger.info(f"Set environment variable CUDA_VISIBLE_DEVICES to {dcu_ids}")
    except Exception as e:
        logger.error(f"{e}, but got {dcu_ids}")
        raise ValueError(f"{e}")

def parse_args():
    '''参数'''
    parser = argparse.ArgumentParser(
@@ -189,8 +169,9 @@ def parse_args():
        help='提问的问题.')
    parser.add_argument(
        '--DCU_ID',
-        default=[0],
-        help='设置DCU')
+        type=str,
+        default='0',
+        help='设置DCU卡号,卡号之间用英文逗号隔开,输入样例:"0,1,2"')
    parser.add_argument(
        '--stream_chat',
        action='store_true',
@@ -201,8 +182,7 @@ def parse_args():

def main():
    args = parse_args()
-    check_envs(args)
-    #infer_test(args)
+    set_envs(args.DCU_ID)
    llm_inference(args)
...
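A small sketch (not part of this commit) of how the new comma-separated --DCU_ID string behaves; set_envs assigns it verbatim, and a caller that needs a device count has to split it first:

import os

dcu_ids = '0,1,2'                              # example --DCU_ID value
os.environ["CUDA_VISIBLE_DEVICES"] = dcu_ids   # what set_envs() does, minus the logging
num_devices = len(dcu_ids.split(','))          # 3; len(dcu_ids) alone would count characters (5)
print(num_devices)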
@@ -66,7 +66,7 @@ class Worker:
        self.KEYWORDS_TEMPLATE = '谷歌搜索是一个通用搜索引擎,可用于访问互联网、查询百科知识、了解时事新闻等。搜索参数类型 string, 内容是短语或关键字,以空格分隔。\n你现在是搜搜小助手,用户提问“{}”,你打算通过谷歌搜索查询相关资料,请提供用于搜索的关键字或短语,不要解释直接给出关键字或短语。' # noqa E501
        self.SECURITY_TEMAPLTE = '判断以下句子是否涉及政治、辱骂、色情、恐暴、宗教、网络暴力、种族歧视等违禁内容,结果用 0~10 表示,不要解释直接给出得分。判断标准:涉其中任一问题直接得 10 分;完全不涉及得 0 分。直接给得分不要解释:“{}”' # noqa E501
        self.PERPLESITY_TEMPLATE = '“question:{} answer:{}”\n阅读以上对话,answer 是否在表达自己不知道,回答越全面得分越少,用0~10表示,不要解释直接给出得分。\n判断标准:准确回答问题得 0 分;答案详尽得 1 分;知道部分答案但有不确定信息得 8 分;知道小部分答案但推荐求助其他人得 9 分;不知道任何答案直接推荐求助别人得 10 分。直接打分不要解释。' # noqa E501
-        self.SUMMARIZE_TEMPLATE = '{} \n 仔细阅读以上内容,总结简短有力点' # noqa E501
+        self.GENERATE_TEMPLATE = '“{}” \n问题:“{}” \n请仔细阅读上述文字, 并使用markdown格式回答问题,直接给出回答不做任何解释。' # noqa E501
        self.MARKDOWN_TEMPLATE = '问题:“{}” \n请使用markdown格式回答此问题'
@@ -92,8 +92,16 @@ class Worker:

    def response_direct_by_llm(self, query):
        # Compliant check
        import ast
        prompt = self.SECURITY_TEMAPLTE.format(query)
-        score = self.agent.call_llm_response(prompt=prompt)
+        scores = self.agent.call_llm_response(prompt=prompt)
        try:
            score_list = ast.literal_eval(scores)
            score = int(score_list[0])
        except Exception as e:
            logger.error("scores:{}, error:{}".format(scores, e))
            return ErrorCode.SCORE_ERROR, e, None
        logger.debug("score:{}, prompt:{}".format(score, prompt))
        if int(score) > 5:
            return ErrorCode.NON_COMPLIANCE_QUESTION, "您的问题中涉及敏感话题,请重新提问。", None
@@ -147,4 +155,4 @@ class Worker:
                return ErrorCode.SUCCESS, response, references
        else:
            return self.response_direct_by_llm(query)
\ No newline at end of file
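For illustration, a standalone sketch (hypothetical helper, not part of this commit) of the scoring contract the new code assumes: the LLM reply has to parse as a list-like Python literal such as "[8]", otherwise ast.literal_eval raises and the caller returns ErrorCode.SCORE_ERROR:

import ast

def parse_security_score(reply: str) -> int:
    '''Parse a reply like "[8]" into an int score; raise on anything else.'''
    score_list = ast.literal_eval(reply)
    return int(score_list[0])

print(parse_security_score('[8]'))     # 8 -> treated as non-compliant (> 5)
try:
    parse_security_score('8分')         # not a Python literal -> error path
except (ValueError, SyntaxError) as e:
    print('score error:', e)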
@@ -32,7 +32,7 @@ def parse_args():
        default=False,
        help='部署LLM推理服务.')
    parser.add_argument(
-        '--accelerate',
+        '--use_vllm',
        default=False,
        type=bool,
        help='LLM推理是否启用加速'
@@ -71,7 +71,7 @@ def run():
    server_process = Process(target=llm_inference,
                             args=(args.config_path,
                                   len(args.DCU_ID),
-                                   args.accelerate,
+                                   args.use_vllm,
                                   server_ready))
    server_process.daemon = True
...
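One caveat worth a standalone sketch (not part of this commit): argparse's type=bool converts any non-empty string to True, so passing --use_vllm False on the command line still yields True; action='store_true' (already used for --stream_chat) gives the usual on/off behaviour:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--use_vllm', default=False, type=bool)
print(parser.parse_args(['--use_vllm', 'False']).use_vllm)  # True, because bool('False') is True
print(parser.parse_args([]).use_vllm)                       # False (the default)

parser2 = argparse.ArgumentParser()
parser2.add_argument('--use_vllm', action='store_true')
print(parser2.parse_args(['--use_vllm']).use_vllm)          # True
print(parser2.parse_args([]).use_vllm)                      # False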