Unverified Commit 60d03021 authored by AllentDan, committed by GitHub
Browse files

launch gradio server directly with hf model (#856)

* launch gradio server directly with hf model

* end session

* end session

* fix api_server backend for gradio

* fix out of boundary index

* remove log
parent 99f4156f
......@@ -120,8 +120,7 @@ class AsyncEngine:
input_ids,
request_output_len=0,
sequence_start=False,
sequence_end=True,
stop=True):
sequence_end=True):
pass
self.id2step[str(session_id)] = 0
if str(session_id) in self.id2generator and self.id2generator[str(
......@@ -265,7 +264,11 @@ class AsyncEngine:
prompt = self.model.messages2prompt(prompt, sequence_start)
input_ids = self.tokenizer.encode(prompt, add_bos=sequence_start)
finish_reason = None
if self.id2step[str(session_id)] + len(
if stop is True:
self.stop_session(session_id)
yield GenOut('', self.id2step[str(session_id)], len(input_ids), 0,
finish_reason)
elif self.id2step[str(session_id)] + len(
input_ids) + request_output_len >= self.tm_model.session_len:
finish_reason = 'length'
yield GenOut('', self.id2step[str(session_id)], len(input_ids), 0,
......
# Copyright (c) OpenMMLab. All rights reserved.
import time
from threading import Lock
from typing import Sequence
......@@ -89,15 +88,24 @@ def cancel_restful_func(state_chatbot: gr.State, cancel_btn: gr.Button,
session_id (int): the session id
"""
yield (state_chatbot, disable_btn, disable_btn)
# stop the session
for out in get_streaming_response(
'',
f'{InterFace.api_server_url}/v1/chat/interactive',
session_id=session_id,
request_output_len=0,
stop=True,
interactive_mode=True):
pass
# end the session
for out in get_streaming_response(
'',
f'{InterFace.api_server_url}/v1/chat/interactive',
session_id=session_id,
request_output_len=0,
stop=True):
interactive_mode=False):
pass
time.sleep(0.5)
# resume the session
messages = []
for qa in state_chatbot:
messages.append(dict(role='user', content=qa[0]))
......
......@@ -6,6 +6,7 @@ def run(model_path_or_server: str,
server_port: int = 6006,
batch_size: int = 32,
tp: int = 1,
model_name: str = None,
**kwargs):
"""chat with AI assistant through web ui.
......@@ -31,8 +32,13 @@ def run(model_path_or_server: str,
run_triton_server(model_path_or_server, server_name, server_port)
else:
from lmdeploy.serve.gradio.turbomind_coupled import run_local
run_local(model_path_or_server, server_name, server_port, batch_size,
tp, **kwargs)
run_local(model_path_or_server,
model_name=model_name,
server_name=server_name,
server_port=server_port,
batch_size=batch_size,
tp=tp,
**kwargs)
if __name__ == '__main__':
......
......@@ -69,13 +69,7 @@ async def reset_local_func(instruction_txtbox: gr.Textbox,
"""
state_chatbot = []
# end the session
async for out in InterFace.async_engine.generate('',
session_id,
request_output_len=1,
stream_response=True,
sequence_start=False,
sequence_end=True):
pass
InterFace.async_engine.end_session(session_id)
return (state_chatbot, state_chatbot, gr.Textbox.update(value=''))
......@@ -90,15 +84,9 @@ async def cancel_local_func(state_chatbot: Sequence, cancel_btn: gr.Button,
reset_btn (gr.Button): the reset button
session_id (int): the session id
"""
yield (state_chatbot, disable_btn, enable_btn)
async for out in InterFace.async_engine.generate('',
session_id,
request_output_len=0,
stream_response=True,
sequence_start=False,
sequence_end=False,
stop=True):
pass
yield (state_chatbot, disable_btn, disable_btn)
InterFace.async_engine.stop_session(session_id)
InterFace.async_engine.end_session(session_id)
messages = []
for qa in state_chatbot:
messages.append(dict(role='user', content=qa[0]))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment