Commit 020c2e2f authored by Rayyyyy

Add code

parent 8f65b603
import json
import argparse

import requests

# Minimal client for the local inference service: send a query, print the reply.
parser = argparse.ArgumentParser()
parser.add_argument('--query', default='请写一首诗')  # default: "please write a poem"
args = parser.parse_args()
print(args.query)

headers = {"Content-Type": "application/json"}
data = {
    "query": args.query,
    "history": []
}
json_str = json.dumps(data)

response = requests.post("http://localhost:8888/inference", headers=headers,
                         data=json_str.encode("utf-8"), verify=False)
str_response = response.content.decode("utf-8")
print(json.loads(str_response))
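A streaming variant of this client is sketched below. It only illustrates the requests streaming API: the chunked-response framing of the /inference endpoint is not defined in this commit, so the URL and the assumption that the server flushes text incrementally are illustrative rather than the service's documented contract.

# Hedged sketch: read a chunked HTTP response incrementally with requests.
# Assumes the server writes text chunks as they are generated; the actual
# response framing of the /inference endpoint is not part of this commit.
import requests

payload = {"query": "请写一首诗", "history": []}
with requests.post("http://localhost:8888/inference",
                   json=payload, stream=True, timeout=60) as resp:
    resp.raise_for_status()
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        if chunk:
            print(chunk, end="", flush=True)
print()

With chunk_size=None, iter_content yields data as it arrives from the socket, which pairs naturally with a server that flushes tokens as they are produced.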
@@ -2,11 +2,12 @@ import time
import os
import configparser
import argparse
import torch
# import torch
import asyncio
from loguru import logger
from aiohttp import web
from multiprocessing import Value
# from multiprocessing import Value
from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -72,7 +73,6 @@ class LLMInference:
sampling_params,
device: str = 'cuda',
use_vllm: bool = False,
stream_chat: bool = False
) -> None:
self.device = device
@@ -80,7 +80,6 @@ class LLMInference:
self.tokenizer = tokenizer
self.sampling_params = sampling_params
self.use_vllm = use_vllm
self.stream_chat = stream_chat
def generate_response(self, prompt, history=[]):
print("generate")
@@ -120,8 +119,9 @@ class LLMInference:
try:
if self.use_vllm:
## vllm
logger.info("****************** use vllm ******************")
prompt_token_ids = [self.tokenizer.apply_chat_template(messages, add_generation_prompt=True)]
logger.info(f"before generate {messages}")
outputs = self.model.generate(prompt_token_ids=prompt_token_ids, sampling_params=self.sampling_params)
output_text = []
@@ -157,12 +157,18 @@ class LLMInference:
def chat_stream(self, prompt: str, history=[]):
'''流式服务 (streaming service)'''
# HuggingFace
logger.info("****************** in chat stream *****************")
current_length = 0
for response, _, _ in self.model.stream_chat(self.tokenizer, prompt, history=history,
past_key_values=None,
return_past_key_values=True):
messages = [{"role": "user", "content": prompt}]
logger.info(f"stream_chat messages {messages}")
for response, _, _ in self.model.stream_chat(self.tokenizer, messages, history=history,
max_length=1024,
past_key_values=None,
return_past_key_values=True):
output_text = response[current_length:]
output_text = self.substitution(output_text)
logger.info(f"using transformers chat_stream, Prompt: {prompt!r}, Generated text: {output_text!r}")
yield output_text
current_length = len(response)
@@ -213,14 +219,15 @@ def llm_inference(args):
llm_infer = LLMInference(model,
tokenzier,
sampling_params,
use_vllm=use_vllm,
stream_chat=stream_chat)
use_vllm=use_vllm)
prompt = input_json['query']
history = input_json['history']
logger.info(f"prompt {prompt}")
if stream_chat:
text = llm_infer.stream_chat(prompt=prompt, history=history)
text = await asyncio.to_thread(llm_infer.chat_stream, prompt=prompt, history=history)
else:
text = llm_infer.chat(prompt=prompt, history=history)
text = await asyncio.to_thread(llm_infer.chat, prompt=prompt, history=history)
end = time.time()
logger.debug('问题:{} 回答:{} \ntimecost {} '.format(prompt, text, end - start))
return web.json_response({'text': text})
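The hunk above moves the blocking model call off the event loop with asyncio.to_thread. A minimal, self-contained sketch of that pattern in an aiohttp handler follows; slow_generate is a hypothetical stand-in for LLMInference.chat, not code from this repository.

# Hedged sketch: offload a blocking call from an aiohttp handler via asyncio.to_thread.
import asyncio
import time

from aiohttp import web


def slow_generate(prompt: str) -> str:
    # Hypothetical stand-in for a blocking model call such as LLMInference.chat.
    time.sleep(1.0)
    return f"echo: {prompt}"


async def handle(request: web.Request) -> web.Response:
    body = await request.json()
    # Run the blocking call in a worker thread so the event loop keeps serving requests.
    text = await asyncio.to_thread(slow_generate, body.get("query", ""))
    return web.json_response({"text": text})


app = web.Application()
app.add_routes([web.post("/inference", handle)])

if __name__ == "__main__":
    web.run_app(app, port=8888)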
@@ -243,8 +250,7 @@ def infer_test(args):
model, tokenzier = init_model(model_path, use_vllm, tensor_parallel_size)
llm_infer = LLMInference(model,
tokenzier,
use_vllm=use_vllm,
stream_chat=stream_chat)
use_vllm=use_vllm)
time_first = time.time()
output_text = llm_infer.chat(args.query)
@@ -272,7 +278,7 @@ def parse_args():
help='config目录')
parser.add_argument(
'--query',
default=['写一首诗'],
default='写一首诗',
help='提问的问题.')
parser.add_argument(
'--DCU_ID',
......