Commit 3edf4e00 authored by chenych

Add vllm stream chat code and update client.py

parent 020c2e2f
@@ -4,6 +4,7 @@ import requests
parse = argparse.ArgumentParser()
parse.add_argument('--query', default='请写一首诗')  # default prompt: "Please write a poem"
parse.add_argument('--use_hf', action='store_true')
args = parse.parse_args()
print(args.query)
@@ -14,8 +15,10 @@ data = {
}
json_str = json.dumps(data)
if args.use_hf:
    response = requests.post("http://localhost:8888/hf_inference", headers=headers, data=json_str.encode("utf-8"), verify=False)
else:
    response = requests.post("http://localhost:8888/vllm_inference", headers=headers, data=json_str.encode("utf-8"), verify=False)
str_response = response.content.decode("utf-8")
print(json.loads(str_response))
@@ -8,8 +8,8 @@ import asyncio
from loguru import logger
from aiohttp import web
# from multiprocessing import Value
from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -65,27 +65,39 @@ def build_history_messages(prompt, history, system: str = None):
    return history_messages
def substitution(output_text):
    # Replace special tokens in the output with their mapped text
    import re
    if isinstance(output_text, list):
        output_text = output_text[0]
    matchObj = re.split('.*(<.*>).*', output_text, flags=re.M | re.I)
    if len(matchObj) > 1:
        obj = matchObj[1]
        replace_str = COMMON.get(obj)
        if replace_str:
            output_text = output_text.replace(obj, replace_str)
            logger.info(f"{obj} replaced with {replace_str}, after: {output_text}")
    return output_text
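A quick illustration of substitution: the COMMON mapping is defined earlier in the file and not shown in this hunk, so the entry referenced below is hypothetical.

# Hypothetical usage; the real COMMON dict is defined above this hunk.
sample = "Hello, I am <NAME>."
print(substitution(sample))
# -> "Hello, I am the assistant." if COMMON maps "<NAME>" to "the assistant";
#    otherwise the input is returned unchanged.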
class LLMInference:
    def __init__(self,
                 model,
                 tokenizer,
                 device: str = 'cuda',
                 ) -> None:
        self.device = device
        self.model = model
        self.tokenizer = tokenizer

    def generate_response(self, prompt, history=[]):
        print("generate")
        output_text = ''
        error = ''
        time_tokenizer = time.time()
        try:
            # chat() now expects a messages list rather than a raw prompt string
            output_text = self.chat([{"role": "user", "content": prompt}], history)
@@ -96,144 +108,170 @@ class LLMInference:
        time_finish = time.time()
        logger.debug('output_text:{} \ntimecost {} '.format(output_text,
                                                            time_finish - time_tokenizer))
        return output_text, error
    def chat(self, messages, history=[]):
        '''Single-turn Q&A'''
        logger.info("****************** in chat ******************")
        try:
            # transformers
            input_ids = self.tokenizer.apply_chat_template(
                messages, add_generation_prompt=True, return_tensors="pt").to('cuda')
            outputs = self.model.generate(
                input_ids,
                max_new_tokens=1024,
            )
            response = outputs[0][input_ids.shape[-1]:]
            generated_text = self.tokenizer.decode(response, skip_special_tokens=True)

            output_text = substitution(generated_text)
            logger.info(f"using transformers, output_text {output_text}")
            return output_text
        except Exception as e:
            logger.error(f"chat inference failed, {e}")
    def chat_stream(self, messages, history=[]):
        '''Streaming generation'''
        # HuggingFace
        logger.info("****************** in chat stream *****************")
        current_length = 0
        logger.info(f"stream_chat messages {messages}")
        for response, _, _ in self.model.stream_chat(self.tokenizer, messages, history=history,
                                                     max_length=1024,
                                                     past_key_values=None,
                                                     return_past_key_values=True):
            output_text = response[current_length:]
            output_text = substitution(output_text)
            logger.info(f"using transformers chat_stream, Prompt: {messages!r}, Generated text: {output_text!r}")
            yield output_text
            current_length = len(response)
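For orientation, a minimal sketch of driving the class directly; model and tokenizer come from init_model below, and the streaming path assumes the underlying model exposes a ChatGLM-style stream_chat method, as the loop above relies on it.

# Sketch only: model/tokenizer as returned by init_model(..., use_vllm=False).
llm = LLMInference(model, tokenizer)
messages = [{"role": "user", "content": "Hello"}]

answer = llm.chat(messages)              # blocking single-turn answer
for chunk in llm.chat_stream(messages):  # incremental chunks (needs model.stream_chat)
    print(chunk, end="", flush=True)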
def init_model(model_path, use_vllm=False, tensor_parallel_size=1):
    ## init models
    logger.info("Starting initial model of LLM")
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    if use_vllm:
        from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams
        sampling_params = SamplingParams(temperature=1,
                                         top_p=0.95,
                                         max_tokens=1024,
                                         early_stopping=False,
                                         stop_token_ids=[tokenizer.eos_token_id]
                                         )
        # Basic vLLM engine configuration
        args = AsyncEngineArgs(model_path)
        args.worker_use_ray = False
        args.engine_use_ray = False
        args.tokenizer = model_path
        args.tensor_parallel_size = tensor_parallel_size
        args.trust_remote_code = True
        args.enforce_eager = True
        args.max_model_len = 1024
        args.dtype = 'float16'
        # Build the async engine
        engine = AsyncLLMEngine.from_engine_args(args)
        return engine, tokenizer, sampling_params
    else:
        # huggingface
        model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).half().cuda().eval()
        return model, tokenizer, None
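Depending on use_vllm, init_model returns a different triple; a short sketch of consuming it (the model path is a placeholder):

# HuggingFace backend: (model, tokenizer, None)
model, tokenizer, _ = init_model("/path/to/model", use_vllm=False)

# vLLM backend: (AsyncLLMEngine, tokenizer, SamplingParams)
engine, tokenizer, sampling_params = init_model("/path/to/model", use_vllm=True,
                                                tensor_parallel_size=1)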
def hf_inference(bind_port, model, tokenizer, stream_chat):
    '''Start the HF web server: accept HTTP requests and answer them via the local LLM inference service.'''
    llm_infer = LLMInference(model, tokenizer)

    async def inference(request):
        start = time.time()
        input_json = await request.json()
        prompt = input_json['query']
        history = input_json['history']
        logger.info(f"prompt {prompt}")
        messages = [{"role": "user", "content": prompt}]
        logger.info("****************** use transformers ******************")
        if stream_chat:
            # chat_stream is a generator; drain it in the worker thread and join the chunks
            text = await asyncio.to_thread(lambda: ''.join(llm_infer.chat_stream(messages=messages, history=history)))
        else:
            text = await asyncio.to_thread(llm_infer.chat, messages=messages, history=history)
        logger.debug('Question: {} Answer: {} \ntimecost {} '.format(prompt, text, time.time() - start))
        return web.json_response({'text': text})

    app = web.Application()
    app.add_routes([web.post('/hf_inference', inference)])
    web.run_app(app, host='0.0.0.0', port=bind_port)
def vllm_inference(bind_port, model, tokenizer, sampling_params, stream_chat):
    '''Start the vLLM web server: accept HTTP requests and answer them via the local LLM inference service.'''
    import uuid
    from typing import AsyncGenerator

    async def inference(request):
        start = time.time()
        input_json = await request.json()
        prompt = input_json['query']
        # history = input_json['history']
        messages = [{"role": "user", "content": prompt}]
        logger.info("****************** use vllm ******************")
        ## generate template
        input_text = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True)
        logger.info(f"The input_text is {input_text}")
        assert model is not None
        request_id = str(uuid.uuid4().hex)
        results_generator = model.generate(input_text, sampling_params=sampling_params, request_id=request_id)

        # Streaming case: yield "\0"-delimited JSON chunks
        async def stream_results() -> AsyncGenerator[bytes, None]:
            async for request_output in results_generator:
                text_outputs = [output.text for output in request_output.outputs]
                ret = {"text": text_outputs}
                print(ret)
                yield (json.dumps(ret) + "\0").encode("utf-8")

        if stream_chat:
            logger.info("****************** in chat stream *****************")
            # The handler runs under aiohttp, so stream through a StreamResponse
            # rather than FastAPI's StreamingResponse.
            resp = web.StreamResponse()
            await resp.prepare(request)
            async for chunk in stream_results():
                await resp.write(chunk)
            await resp.write_eof()
            return resp

        # Non-streaming case
        logger.info("****************** in chat ******************")
        final_output = None
        async for request_output in results_generator:
            # if await request.is_disconnected():
            #     # Abort the request if the client disconnects.
            #     await engine.abort(request_id)
            #     return Response(status_code=499)
            final_output = request_output
        assert final_output is not None
        text = [output.text for output in final_output.outputs]
        end = time.time()
        output_text = substitution(text)
        logger.debug('Question: {} Answer: {} \ntimecost {} '.format(prompt, text, end - start))
        return web.json_response({'text': output_text})

    app = web.Application()
    app.add_routes([web.post('/vllm_inference', inference)])
    web.run_app(app, host='0.0.0.0', port=bind_port)
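Because the streaming branch (active when stream_chat is enabled) writes NUL-delimited JSON chunks, a client has to split on that delimiter; a rough sketch against the /vllm_inference endpoint above, with the port as in client.py and the payload keys as read by the handler:

import json
import requests

payload = {"query": "请写一首诗", "history": []}
with requests.post("http://localhost:8888/vllm_inference", json=payload, stream=True) as resp:
    buffer = b""
    for chunk in resp.iter_content(chunk_size=None):
        buffer += chunk
        while b"\0" in buffer:
            piece, buffer = buffer.split(b"\0", 1)
            print(json.loads(piece.decode("utf-8"))["text"])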
@@ -292,7 +330,21 @@ def parse_args():
def main():
    args = parse_args()
    set_envs(args.DCU_ID)
    # configs
    config = configparser.ConfigParser()
    config.read(args.config_path)
    bind_port = int(config['default']['bind_port'])
    model_path = config['llm']['local_llm_path']
    use_vllm = config.getboolean('llm', 'use_vllm')
    tensor_parallel_size = config.getint('llm', 'tensor_parallel_size')
    stream_chat = config.getboolean('llm', 'stream_chat')
    logger.info(f"Get params: model_path {model_path}, use_vllm {use_vllm}, tensor_parallel_size {tensor_parallel_size}, stream_chat {stream_chat}")
    model, tokenizer, sampling_params = init_model(model_path, use_vllm, tensor_parallel_size)
    if use_vllm:
        vllm_inference(bind_port, model, tokenizer, sampling_params, stream_chat)
    else:
        hf_inference(bind_port, model, tokenizer, stream_chat)
    # infer_test(args)
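For reference, main() reads the following sections and keys from the INI file passed as --config_path; a minimal sketch that writes such a file (the values and the config.ini filename are placeholders, not from this commit):

import configparser

config = configparser.ConfigParser()
config["default"] = {"bind_port": "8888"}
config["llm"] = {
    "local_llm_path": "/path/to/model",  # placeholder
    "use_vllm": "true",
    "tensor_parallel_size": "1",
    "stream_chat": "false",
}
with open("config.ini", "w") as f:  # filename is an assumption
    config.write(f)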
......
def init_model(model_path, use_vllm=False, tensor_parallel_size=1):
    ## init models
    # huggingface
    logger.info("Starting initial model of Llama - vllm")
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    # vllm
    sampling_params = SamplingParams(temperature=1,
                                     top_p=0.95,
                                     max_tokens=1024,
                                     early_stopping=False,
                                     stop_token_ids=[tokenizer.eos_token_id]
                                     )
    # Basic vLLM engine configuration
    args = AsyncEngineArgs(model_path)
    args.worker_use_ray = False
    args.engine_use_ray = False
    args.tokenizer = model_path
    args.tensor_parallel_size = tensor_parallel_size
    args.trust_remote_code = True
    args.enforce_eager = True
    args.max_model_len = 1024
    args.dtype = 'float16'
    # Build the async engine
    engine = AsyncLLMEngine.from_engine_args(args)
    return engine, tokenizer, sampling_params
def llm_inference(args):
    '''Start the web server: accept HTTP requests and answer them via the local LLM inference service.'''
    config = configparser.ConfigParser()
    config.read(args.config_path)
    bind_port = int(config['default']['bind_port'])
    model_path = config['llm']['local_llm_path']
    tensor_parallel_size = config.getint('llm', 'tensor_parallel_size')
    use_vllm = config.getboolean('llm', 'use_vllm')
    stream_chat = config.getboolean('llm', 'stream_chat')
    logger.info(f"Get params: model_path {model_path}, use_vllm {use_vllm}, tensor_parallel_size {tensor_parallel_size}, stream_chat {stream_chat}")
    model, tokenizer, sampling_params = init_model(model_path, tensor_parallel_size=tensor_parallel_size)

    async def inference(request):
        start = time.time()
        input_json = await request.json()
        prompt = input_json['query']
        history = input_json['history']
        messages = [{"role": "user", "content": prompt}]
        logger.info("****************** use vllm ******************")
        logger.info(f"before generate {messages}")
        ## build the chat-formatted prompt
        text = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True)
        print(text)
        assert model is not None
        request_id = str(uuid.uuid4().hex)
        results_generator = model.generate(inputs=text, sampling_params=sampling_params, request_id=request_id)

        # Streaming case: yield "\0"-delimited JSON chunks
        async def stream_results() -> AsyncGenerator[bytes, None]:
            async for request_output in results_generator:
                prompt = request_output.prompt
                text_outputs = [output.text for output in request_output.outputs]
                ret = {"text": text_outputs}
                yield (json.dumps(ret) + "\0").encode("utf-8")

        if stream_chat:
            # The handler runs under aiohttp, so stream through a StreamResponse
            # rather than FastAPI's StreamingResponse.
            resp = web.StreamResponse()
            await resp.prepare(request)
            async for chunk in stream_results():
                await resp.write(chunk)
            await resp.write_eof()
            return resp

        # Non-streaming case
        final_output = None
        async for request_output in results_generator:
            # if await request.is_disconnected():
            #     # Abort the request if the client disconnects.
            #     await engine.abort(request_id)
            #     return Response(status_code=499)
            final_output = request_output
        assert final_output is not None
        text = [output.text for output in final_output.outputs]
        end = time.time()
        logger.debug('Question: {} Answer: {} \ntimecost {} '.format(prompt, text, end - start))
        return web.json_response({'text': text})

    app = web.Application()
    app.add_routes([web.post('/inference', inference)])
    web.run_app(app, host='0.0.0.0', port=bind_port)
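For completeness, the AsyncLLMEngine path can also be exercised outside the web handler; a rough sketch under the same assumptions as init_model above (the generate keyword arguments mirror the call in this file; exact signatures differ across vLLM versions):

import asyncio
import uuid

async def run_once(engine, tokenizer, sampling_params, prompt):
    # Build the chat-formatted prompt the same way the handler does.
    messages = [{"role": "user", "content": prompt}]
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    final = None
    async for out in engine.generate(inputs=text,
                                     sampling_params=sampling_params,
                                     request_id=str(uuid.uuid4().hex)):
        final = out  # each item is the cumulative RequestOutput so far
    return [o.text for o in final.outputs]

# engine, tokenizer, sampling_params = init_model("/path/to/model")  # placeholder path
# print(asyncio.run(run_once(engine, tokenizer, sampling_params, "请写一首诗")))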