You need to sign in or sign up before continuing.
Unverified Commit 11d10930 authored by aisensiy's avatar aisensiy Committed by GitHub
Browse files

Manage session id using random int for gradio local mode (#553)



* Use session id from gradio state

* use a new session id after reset

* rename session id like a state

* update comments

* reformat files

* init session id on block loaded

* use auto increased session id

* remove session id textbox

* apply to api_server and tritonserver

* update docstring

* add lock for safety

---------
Co-authored-by: default avatarAllentDan <AllentDan@yeah.net>
parent 85d2f662
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import threading
import time import time
from threading import Lock
from typing import Sequence from typing import Sequence
import gradio as gr import gradio as gr
...@@ -8,35 +8,27 @@ import gradio as gr ...@@ -8,35 +8,27 @@ import gradio as gr
from lmdeploy.serve.gradio.constants import CSS, THEME, disable_btn, enable_btn from lmdeploy.serve.gradio.constants import CSS, THEME, disable_btn, enable_btn
from lmdeploy.serve.openai.api_client import (get_model_list, from lmdeploy.serve.openai.api_client import (get_model_list,
get_streaming_response) get_streaming_response)
from lmdeploy.serve.openai.api_server import ip2id
class InterFace: class InterFace:
api_server_url: str = None api_server_url: str = None
global_session_id: int = 0
lock = Lock()
def chat_stream_restful( def chat_stream_restful(instruction: str, state_chatbot: Sequence,
instruction: str, cancel_btn: gr.Button, reset_btn: gr.Button,
state_chatbot: Sequence, session_id: int):
cancel_btn: gr.Button,
reset_btn: gr.Button,
request: gr.Request,
):
"""Chat with AI assistant. """Chat with AI assistant.
Args: Args:
instruction (str): user's prompt instruction (str): user's prompt
state_chatbot (Sequence): the chatting history state_chatbot (Sequence): the chatting history
request (gr.Request): the request from a user session_id (int): the session id
""" """
session_id = threading.current_thread().ident
if request is not None:
session_id = ip2id(request.kwargs['client']['host'])
bot_summarized_response = ''
state_chatbot = state_chatbot + [(instruction, None)] state_chatbot = state_chatbot + [(instruction, None)]
yield (state_chatbot, state_chatbot, disable_btn, enable_btn, yield (state_chatbot, state_chatbot, disable_btn, enable_btn)
f'{bot_summarized_response}'.strip())
for response, tokens, finish_reason in get_streaming_response( for response, tokens, finish_reason in get_streaming_response(
instruction, instruction,
...@@ -56,27 +48,21 @@ def chat_stream_restful( ...@@ -56,27 +48,21 @@ def chat_stream_restful(
state_chatbot[-1] = (state_chatbot[-1][0], state_chatbot[-1] = (state_chatbot[-1][0],
state_chatbot[-1][1] + response state_chatbot[-1][1] + response
) # piece by piece ) # piece by piece
yield (state_chatbot, state_chatbot, enable_btn, disable_btn, yield (state_chatbot, state_chatbot, enable_btn, disable_btn)
f'{bot_summarized_response}'.strip())
yield (state_chatbot, state_chatbot, disable_btn, enable_btn, yield (state_chatbot, state_chatbot, disable_btn, enable_btn)
f'{bot_summarized_response}'.strip())
def reset_restful_func(instruction_txtbox: gr.Textbox, state_chatbot: gr.State, def reset_restful_func(instruction_txtbox: gr.Textbox, state_chatbot: gr.State,
request: gr.Request): session_id: int):
"""reset the session. """reset the session.
Args: Args:
instruction_txtbox (str): user's prompt instruction_txtbox (str): user's prompt
state_chatbot (Sequence): the chatting history state_chatbot (Sequence): the chatting history
request (gr.Request): the request from a user session_id (int): the session id
""" """
state_chatbot = [] state_chatbot = []
session_id = threading.current_thread().ident
if request is not None:
session_id = ip2id(request.kwargs['client']['host'])
# end the session # end the session
for response, tokens, finish_reason in get_streaming_response( for response, tokens, finish_reason in get_streaming_response(
'', '',
...@@ -94,18 +80,15 @@ def reset_restful_func(instruction_txtbox: gr.Textbox, state_chatbot: gr.State, ...@@ -94,18 +80,15 @@ def reset_restful_func(instruction_txtbox: gr.Textbox, state_chatbot: gr.State,
def cancel_restful_func(state_chatbot: gr.State, cancel_btn: gr.Button, def cancel_restful_func(state_chatbot: gr.State, cancel_btn: gr.Button,
reset_btn: gr.Button, request: gr.Request): reset_btn: gr.Button, session_id: int):
"""stop the session. """stop the session.
Args: Args:
instruction_txtbox (str): user's prompt instruction_txtbox (str): user's prompt
state_chatbot (Sequence): the chatting history state_chatbot (Sequence): the chatting history
request (gr.Request): the request from a user session_id (int): the session id
""" """
yield (state_chatbot, disable_btn, disable_btn) yield (state_chatbot, disable_btn, disable_btn)
session_id = threading.current_thread().ident
if request is not None:
session_id = ip2id(request.kwargs['client']['host'])
# end the session # end the session
for out in get_streaming_response( for out in get_streaming_response(
'', '',
...@@ -152,6 +135,7 @@ def run_api_server(api_server_url: str, ...@@ -152,6 +135,7 @@ def run_api_server(api_server_url: str,
with gr.Blocks(css=CSS, theme=THEME) as demo: with gr.Blocks(css=CSS, theme=THEME) as demo:
state_chatbot = gr.State([]) state_chatbot = gr.State([])
state_session_id = gr.State(0)
with gr.Column(elem_id='container'): with gr.Column(elem_id='container'):
gr.Markdown('## LMDeploy Playground') gr.Markdown('## LMDeploy Playground')
...@@ -164,25 +148,34 @@ def run_api_server(api_server_url: str, ...@@ -164,25 +148,34 @@ def run_api_server(api_server_url: str,
cancel_btn = gr.Button(value='Cancel', interactive=False) cancel_btn = gr.Button(value='Cancel', interactive=False)
reset_btn = gr.Button(value='Reset') reset_btn = gr.Button(value='Reset')
send_event = instruction_txtbox.submit( send_event = instruction_txtbox.submit(chat_stream_restful, [
chat_stream_restful, instruction_txtbox, state_chatbot, cancel_btn, reset_btn,
[instruction_txtbox, state_chatbot, cancel_btn, reset_btn], state_session_id
[state_chatbot, chatbot, cancel_btn, reset_btn]) ], [state_chatbot, chatbot, cancel_btn, reset_btn])
instruction_txtbox.submit( instruction_txtbox.submit(
lambda: gr.Textbox.update(value=''), lambda: gr.Textbox.update(value=''),
[], [],
[instruction_txtbox], [instruction_txtbox],
) )
cancel_btn.click(cancel_restful_func, cancel_btn.click(
[state_chatbot, cancel_btn, reset_btn], cancel_restful_func,
[state_chatbot, cancel_btn, reset_btn], [state_chatbot, cancel_btn, reset_btn, state_session_id],
cancels=[send_event]) [state_chatbot, cancel_btn, reset_btn],
cancels=[send_event])
reset_btn.click(reset_restful_func, reset_btn.click(reset_restful_func,
[instruction_txtbox, state_chatbot], [instruction_txtbox, state_chatbot, state_session_id],
[state_chatbot, chatbot, instruction_txtbox], [state_chatbot, chatbot, instruction_txtbox],
cancels=[send_event]) cancels=[send_event])
def init():
with InterFace.lock:
InterFace.global_session_id += 1
new_session_id = InterFace.global_session_id
return new_session_id
demo.load(init, inputs=None, outputs=[state_session_id])
print(f'server is gonna mount on: http://{server_name}:{server_port}') print(f'server is gonna mount on: http://{server_name}:{server_port}')
demo.queue(concurrency_count=batch_size, max_size=100, demo.queue(concurrency_count=batch_size, max_size=100,
api_open=True).launch( api_open=True).launch(
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import os import os
import threading
from functools import partial from functools import partial
from threading import Lock
from typing import Sequence from typing import Sequence
import gradio as gr import gradio as gr
from lmdeploy.serve.gradio.constants import CSS, THEME, disable_btn, enable_btn from lmdeploy.serve.gradio.constants import CSS, THEME, disable_btn, enable_btn
from lmdeploy.serve.openai.api_server import ip2id
from lmdeploy.serve.turbomind.chatbot import Chatbot from lmdeploy.serve.turbomind.chatbot import Chatbot
class InterFace:
global_session_id: int = 0
lock = Lock()
def chat_stream(state_chatbot: Sequence, llama_chatbot: Chatbot, def chat_stream(state_chatbot: Sequence, llama_chatbot: Chatbot,
cancel_btn: gr.Button, reset_btn: gr.Button, cancel_btn: gr.Button, reset_btn: gr.Button, session_id: int):
request: gr.Request):
"""Chat with AI assistant. """Chat with AI assistant.
Args: Args:
...@@ -22,12 +25,9 @@ def chat_stream(state_chatbot: Sequence, llama_chatbot: Chatbot, ...@@ -22,12 +25,9 @@ def chat_stream(state_chatbot: Sequence, llama_chatbot: Chatbot,
llama_chatbot (Chatbot): the instance of a chatbot llama_chatbot (Chatbot): the instance of a chatbot
cancel_btn (bool): enable the cancel button or not cancel_btn (bool): enable the cancel button or not
reset_btn (bool): enable the reset button or not reset_btn (bool): enable the reset button or not
request (gr.Request): the request from a user session_id (int): the session id
""" """
instruction = state_chatbot[-1][0] instruction = state_chatbot[-1][0]
session_id = threading.current_thread().ident
if request is not None:
session_id = ip2id(request.kwargs['client']['host'])
bot_response = llama_chatbot.stream_infer( bot_response = llama_chatbot.stream_infer(
session_id, instruction, f'{session_id}-{len(state_chatbot)}') session_id, instruction, f'{session_id}-{len(state_chatbot)}')
...@@ -92,6 +92,7 @@ def run_triton_server(triton_server_addr: str, ...@@ -92,6 +92,7 @@ def run_triton_server(triton_server_addr: str,
llama_chatbot = gr.State( llama_chatbot = gr.State(
Chatbot(triton_server_addr, log_level=log_level, display=True)) Chatbot(triton_server_addr, log_level=log_level, display=True))
state_chatbot = gr.State([]) state_chatbot = gr.State([])
state_session_id = gr.State(0)
model_name = llama_chatbot.value.model_name model_name = llama_chatbot.value.model_name
reset_all = partial(reset_all_func, reset_all = partial(reset_all_func,
model_name=model_name, model_name=model_name,
...@@ -110,10 +111,10 @@ def run_triton_server(triton_server_addr: str, ...@@ -110,10 +111,10 @@ def run_triton_server(triton_server_addr: str,
send_event = instruction_txtbox.submit( send_event = instruction_txtbox.submit(
add_instruction, [instruction_txtbox, state_chatbot], add_instruction, [instruction_txtbox, state_chatbot],
[instruction_txtbox, state_chatbot]).then( [instruction_txtbox, state_chatbot]).then(chat_stream, [
chat_stream, state_chatbot, llama_chatbot, cancel_btn, reset_btn,
[state_chatbot, llama_chatbot, cancel_btn, reset_btn], state_session_id
[state_chatbot, chatbot, cancel_btn, reset_btn]) ], [state_chatbot, chatbot, cancel_btn, reset_btn])
cancel_btn.click(cancel_func, cancel_btn.click(cancel_func,
[state_chatbot, llama_chatbot, cancel_btn, reset_btn], [state_chatbot, llama_chatbot, cancel_btn, reset_btn],
...@@ -125,6 +126,14 @@ def run_triton_server(triton_server_addr: str, ...@@ -125,6 +126,14 @@ def run_triton_server(triton_server_addr: str,
[llama_chatbot, state_chatbot, chatbot, instruction_txtbox], [llama_chatbot, state_chatbot, chatbot, instruction_txtbox],
cancels=[send_event]) cancels=[send_event])
def init():
with InterFace.lock:
InterFace.global_session_id += 1
new_session_id = InterFace.global_session_id
return new_session_id
demo.load(init, inputs=None, outputs=[state_session_id])
print(f'server is gonna mount on: http://{server_name}:{server_port}') print(f'server is gonna mount on: http://{server_name}:{server_port}')
demo.queue(concurrency_count=4, max_size=100, api_open=True).launch( demo.queue(concurrency_count=4, max_size=100, api_open=True).launch(
max_threads=10, max_threads=10,
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import threading from threading import Lock
from typing import Sequence from typing import Sequence
import gradio as gr import gradio as gr
from lmdeploy.serve.async_engine import AsyncEngine from lmdeploy.serve.async_engine import AsyncEngine
from lmdeploy.serve.gradio.constants import CSS, THEME, disable_btn, enable_btn from lmdeploy.serve.gradio.constants import CSS, THEME, disable_btn, enable_btn
from lmdeploy.serve.openai.api_server import ip2id
class InterFace: class InterFace:
async_engine: AsyncEngine = None async_engine: AsyncEngine = None
global_session_id: int = 0
lock = Lock()
async def chat_stream_local( async def chat_stream_local(
...@@ -18,25 +19,20 @@ async def chat_stream_local( ...@@ -18,25 +19,20 @@ async def chat_stream_local(
state_chatbot: Sequence, state_chatbot: Sequence,
cancel_btn: gr.Button, cancel_btn: gr.Button,
reset_btn: gr.Button, reset_btn: gr.Button,
request: gr.Request, session_id: int,
): ):
"""Chat with AI assistant. """Chat with AI assistant.
Args: Args:
instruction (str): user's prompt instruction (str): user's prompt
state_chatbot (Sequence): the chatting history state_chatbot (Sequence): the chatting history
cancel_btn (bool): enable the cancel button or not cancel_btn (gr.Button): the cancel button
reset_btn (bool): enable the reset button or not reset_btn (gr.Button): the reset button
request (gr.Request): the request from a user session_id (int): the session id
""" """
session_id = threading.current_thread().ident
if request is not None:
session_id = ip2id(request.kwargs['client']['host'])
bot_summarized_response = ''
state_chatbot = state_chatbot + [(instruction, None)] state_chatbot = state_chatbot + [(instruction, None)]
yield (state_chatbot, state_chatbot, disable_btn, enable_btn, yield (state_chatbot, state_chatbot, disable_btn, enable_btn)
f'{bot_summarized_response}'.strip())
async for outputs in InterFace.async_engine.generate( async for outputs in InterFace.async_engine.generate(
instruction, instruction,
...@@ -57,27 +53,21 @@ async def chat_stream_local( ...@@ -57,27 +53,21 @@ async def chat_stream_local(
state_chatbot[-1] = (state_chatbot[-1][0], state_chatbot[-1] = (state_chatbot[-1][0],
state_chatbot[-1][1] + response state_chatbot[-1][1] + response
) # piece by piece ) # piece by piece
yield (state_chatbot, state_chatbot, enable_btn, disable_btn, yield (state_chatbot, state_chatbot, enable_btn, disable_btn)
f'{bot_summarized_response}'.strip())
yield (state_chatbot, state_chatbot, disable_btn, enable_btn, yield (state_chatbot, state_chatbot, disable_btn, enable_btn)
f'{bot_summarized_response}'.strip())
async def reset_local_func(instruction_txtbox: gr.Textbox, async def reset_local_func(instruction_txtbox: gr.Textbox,
state_chatbot: gr.State, request: gr.Request): state_chatbot: Sequence, session_id: int):
"""reset the session. """reset the session.
Args: Args:
instruction_txtbox (str): user's prompt instruction_txtbox (str): user's prompt
state_chatbot (Sequence): the chatting history state_chatbot (Sequence): the chatting history
request (gr.Request): the request from a user session_id (int): the session id
""" """
state_chatbot = [] state_chatbot = []
session_id = threading.current_thread().ident
if request is not None:
session_id = ip2id(request.kwargs['client']['host'])
# end the session # end the session
async for out in InterFace.async_engine.generate('', async for out in InterFace.async_engine.generate('',
session_id, session_id,
...@@ -86,29 +76,21 @@ async def reset_local_func(instruction_txtbox: gr.Textbox, ...@@ -86,29 +76,21 @@ async def reset_local_func(instruction_txtbox: gr.Textbox,
sequence_start=False, sequence_start=False,
sequence_end=True): sequence_end=True):
pass pass
return (state_chatbot, state_chatbot, gr.Textbox.update(value=''))
return (
state_chatbot,
state_chatbot,
gr.Textbox.update(value=''),
)
async def cancel_local_func(state_chatbot: gr.State, cancel_btn: gr.Button, async def cancel_local_func(state_chatbot: Sequence, cancel_btn: gr.Button,
reset_btn: gr.Button, request: gr.Request): reset_btn: gr.Button, session_id: int):
"""stop the session. """stop the session.
Args: Args:
instruction_txtbox (str): user's prompt
state_chatbot (Sequence): the chatting history state_chatbot (Sequence): the chatting history
cancel_btn (bool): enable the cancel button or not cancel_btn (gr.Button): the cancel button
reset_btn (bool): enable the reset button or not reset_btn (gr.Button): the reset button
request (gr.Request): the request from a user session_id (int): the session id
""" """
yield (state_chatbot, disable_btn, disable_btn) yield (state_chatbot, disable_btn, enable_btn)
session_id = threading.current_thread().ident
if request is not None:
session_id = ip2id(request.kwargs['client']['host'])
# end the session
async for out in InterFace.async_engine.generate('', async for out in InterFace.async_engine.generate('',
session_id, session_id,
request_output_len=0, request_output_len=0,
...@@ -152,6 +134,7 @@ def run_local(model_path: str, ...@@ -152,6 +134,7 @@ def run_local(model_path: str,
with gr.Blocks(css=CSS, theme=THEME) as demo: with gr.Blocks(css=CSS, theme=THEME) as demo:
state_chatbot = gr.State([]) state_chatbot = gr.State([])
state_session_id = gr.State(0)
with gr.Column(elem_id='container'): with gr.Column(elem_id='container'):
gr.Markdown('## LMDeploy Playground') gr.Markdown('## LMDeploy Playground')
...@@ -166,24 +149,34 @@ def run_local(model_path: str, ...@@ -166,24 +149,34 @@ def run_local(model_path: str,
cancel_btn = gr.Button(value='Cancel', interactive=False) cancel_btn = gr.Button(value='Cancel', interactive=False)
reset_btn = gr.Button(value='Reset') reset_btn = gr.Button(value='Reset')
send_event = instruction_txtbox.submit( send_event = instruction_txtbox.submit(chat_stream_local, [
chat_stream_local, instruction_txtbox, state_chatbot, cancel_btn, reset_btn,
[instruction_txtbox, state_chatbot, cancel_btn, reset_btn], state_session_id
[state_chatbot, chatbot, cancel_btn, reset_btn]) ], [state_chatbot, chatbot, cancel_btn, reset_btn])
instruction_txtbox.submit( instruction_txtbox.submit(
lambda: gr.Textbox.update(value=''), lambda: gr.Textbox.update(value=''),
[], [],
[instruction_txtbox], [instruction_txtbox],
) )
cancel_btn.click(cancel_local_func, cancel_btn.click(
[state_chatbot, cancel_btn, reset_btn], cancel_local_func,
[state_chatbot, cancel_btn, reset_btn], [state_chatbot, cancel_btn, reset_btn, state_session_id],
cancels=[send_event]) [state_chatbot, cancel_btn, reset_btn],
cancels=[send_event])
reset_btn.click(reset_local_func, [instruction_txtbox, state_chatbot],
reset_btn.click(reset_local_func,
[instruction_txtbox, state_chatbot, state_session_id],
[state_chatbot, chatbot, instruction_txtbox], [state_chatbot, chatbot, instruction_txtbox],
cancels=[send_event]) cancels=[send_event])
def init():
with InterFace.lock:
InterFace.global_session_id += 1
new_session_id = InterFace.global_session_id
return new_session_id
demo.load(init, inputs=None, outputs=[state_session_id])
print(f'server is gonna mount on: http://{server_name}:{server_port}') print(f'server is gonna mount on: http://{server_name}:{server_port}')
demo.queue(concurrency_count=batch_size, max_size=100, demo.queue(concurrency_count=batch_size, max_size=100,
api_open=True).launch( api_open=True).launch(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment