Unverified commit 11d10930, authored by aisensiy and committed via GitHub
Browse files

Manage session id using random int for gradio local mode (#553)



* Use session id from gradio state

* use a new session id after reset

* rename session id like a state

* update comments

* reformat files

* init session id on block loaded

* use auto increased session id

* remove session id textbox

* apply to api_server and tritonserver

* update docstring

* add lock for safety

---------
Co-authored-by: AllentDan <AllentDan@yeah.net>
parent 85d2f662
# Copyright (c) OpenMMLab. All rights reserved.
import threading
import time
from threading import Lock
from typing import Sequence
import gradio as gr
......@@ -8,35 +8,27 @@ import gradio as gr
from lmdeploy.serve.gradio.constants import CSS, THEME, disable_btn, enable_btn
from lmdeploy.serve.openai.api_client import (get_model_list,
get_streaming_response)
from lmdeploy.serve.openai.api_server import ip2id
class InterFace:
    """Process-wide shared state for the restful-api gradio demo.

    All attributes are class-level so every gradio event handler in the
    process sees the same values; instances are never created.
    """
    # Base URL of the backing openai-style api_server; assigned once at startup.
    api_server_url: str = None
    # Monotonically increasing counter used to hand out unique session ids
    # (incremented by init() each time a browser session loads the Blocks UI).
    global_session_id: int = 0
    # Guards global_session_id increments across concurrent gradio events.
    lock = Lock()
def chat_stream_restful(
instruction: str,
state_chatbot: Sequence,
cancel_btn: gr.Button,
reset_btn: gr.Button,
request: gr.Request,
):
def chat_stream_restful(instruction: str, state_chatbot: Sequence,
cancel_btn: gr.Button, reset_btn: gr.Button,
session_id: int):
"""Chat with AI assistant.
Args:
instruction (str): user's prompt
state_chatbot (Sequence): the chatting history
request (gr.Request): the request from a user
session_id (int): the session id
"""
session_id = threading.current_thread().ident
if request is not None:
session_id = ip2id(request.kwargs['client']['host'])
bot_summarized_response = ''
state_chatbot = state_chatbot + [(instruction, None)]
yield (state_chatbot, state_chatbot, disable_btn, enable_btn,
f'{bot_summarized_response}'.strip())
yield (state_chatbot, state_chatbot, disable_btn, enable_btn)
for response, tokens, finish_reason in get_streaming_response(
instruction,
......@@ -56,27 +48,21 @@ def chat_stream_restful(
state_chatbot[-1] = (state_chatbot[-1][0],
state_chatbot[-1][1] + response
) # piece by piece
yield (state_chatbot, state_chatbot, enable_btn, disable_btn,
f'{bot_summarized_response}'.strip())
yield (state_chatbot, state_chatbot, enable_btn, disable_btn)
yield (state_chatbot, state_chatbot, disable_btn, enable_btn,
f'{bot_summarized_response}'.strip())
yield (state_chatbot, state_chatbot, disable_btn, enable_btn)
def reset_restful_func(instruction_txtbox: gr.Textbox, state_chatbot: gr.State,
request: gr.Request):
session_id: int):
"""reset the session.
Args:
instruction_txtbox (str): user's prompt
state_chatbot (Sequence): the chatting history
request (gr.Request): the request from a user
session_id (int): the session id
"""
state_chatbot = []
session_id = threading.current_thread().ident
if request is not None:
session_id = ip2id(request.kwargs['client']['host'])
# end the session
for response, tokens, finish_reason in get_streaming_response(
'',
......@@ -94,18 +80,15 @@ def reset_restful_func(instruction_txtbox: gr.Textbox, state_chatbot: gr.State,
def cancel_restful_func(state_chatbot: gr.State, cancel_btn: gr.Button,
reset_btn: gr.Button, request: gr.Request):
reset_btn: gr.Button, session_id: int):
"""stop the session.
Args:
instruction_txtbox (str): user's prompt
state_chatbot (Sequence): the chatting history
request (gr.Request): the request from a user
session_id (int): the session id
"""
yield (state_chatbot, disable_btn, disable_btn)
session_id = threading.current_thread().ident
if request is not None:
session_id = ip2id(request.kwargs['client']['host'])
# end the session
for out in get_streaming_response(
'',
......@@ -152,6 +135,7 @@ def run_api_server(api_server_url: str,
with gr.Blocks(css=CSS, theme=THEME) as demo:
state_chatbot = gr.State([])
state_session_id = gr.State(0)
with gr.Column(elem_id='container'):
gr.Markdown('## LMDeploy Playground')
......@@ -164,25 +148,34 @@ def run_api_server(api_server_url: str,
cancel_btn = gr.Button(value='Cancel', interactive=False)
reset_btn = gr.Button(value='Reset')
send_event = instruction_txtbox.submit(
chat_stream_restful,
[instruction_txtbox, state_chatbot, cancel_btn, reset_btn],
[state_chatbot, chatbot, cancel_btn, reset_btn])
send_event = instruction_txtbox.submit(chat_stream_restful, [
instruction_txtbox, state_chatbot, cancel_btn, reset_btn,
state_session_id
], [state_chatbot, chatbot, cancel_btn, reset_btn])
instruction_txtbox.submit(
lambda: gr.Textbox.update(value=''),
[],
[instruction_txtbox],
)
cancel_btn.click(cancel_restful_func,
[state_chatbot, cancel_btn, reset_btn],
[state_chatbot, cancel_btn, reset_btn],
cancels=[send_event])
cancel_btn.click(
cancel_restful_func,
[state_chatbot, cancel_btn, reset_btn, state_session_id],
[state_chatbot, cancel_btn, reset_btn],
cancels=[send_event])
reset_btn.click(reset_restful_func,
[instruction_txtbox, state_chatbot],
[instruction_txtbox, state_chatbot, state_session_id],
[state_chatbot, chatbot, instruction_txtbox],
cancels=[send_event])
def init():
    """Hand out a fresh, process-unique session id for a new UI session.

    Called via ``demo.load`` whenever a browser loads the Blocks page, so
    each visitor gets a distinct id stored in ``state_session_id``.
    """
    # The lock makes the increment-and-read atomic across concurrent loads.
    with InterFace.lock:
        InterFace.global_session_id += 1
        return InterFace.global_session_id
demo.load(init, inputs=None, outputs=[state_session_id])
print(f'server is gonna mount on: http://{server_name}:{server_port}')
demo.queue(concurrency_count=batch_size, max_size=100,
api_open=True).launch(
......
# Copyright (c) OpenMMLab. All rights reserved.
import os
import threading
from functools import partial
from threading import Lock
from typing import Sequence
import gradio as gr
from lmdeploy.serve.gradio.constants import CSS, THEME, disable_btn, enable_btn
from lmdeploy.serve.openai.api_server import ip2id
from lmdeploy.serve.turbomind.chatbot import Chatbot
class InterFace:
    """Process-wide shared state for the triton-server gradio demo.

    Attributes are class-level so all gradio event handlers share them;
    the class is used as a namespace and never instantiated.
    """
    # Monotonically increasing counter used to hand out unique session ids
    # (incremented by init() each time a browser session loads the Blocks UI).
    global_session_id: int = 0
    # Guards global_session_id increments across concurrent gradio events.
    lock = Lock()
def chat_stream(state_chatbot: Sequence, llama_chatbot: Chatbot,
cancel_btn: gr.Button, reset_btn: gr.Button,
request: gr.Request):
cancel_btn: gr.Button, reset_btn: gr.Button, session_id: int):
"""Chat with AI assistant.
Args:
......@@ -22,12 +25,9 @@ def chat_stream(state_chatbot: Sequence, llama_chatbot: Chatbot,
llama_chatbot (Chatbot): the instance of a chatbot
cancel_btn (bool): enable the cancel button or not
reset_btn (bool): enable the reset button or not
request (gr.Request): the request from a user
session_id (int): the session id
"""
instruction = state_chatbot[-1][0]
session_id = threading.current_thread().ident
if request is not None:
session_id = ip2id(request.kwargs['client']['host'])
bot_response = llama_chatbot.stream_infer(
session_id, instruction, f'{session_id}-{len(state_chatbot)}')
......@@ -92,6 +92,7 @@ def run_triton_server(triton_server_addr: str,
llama_chatbot = gr.State(
Chatbot(triton_server_addr, log_level=log_level, display=True))
state_chatbot = gr.State([])
state_session_id = gr.State(0)
model_name = llama_chatbot.value.model_name
reset_all = partial(reset_all_func,
model_name=model_name,
......@@ -110,10 +111,10 @@ def run_triton_server(triton_server_addr: str,
send_event = instruction_txtbox.submit(
add_instruction, [instruction_txtbox, state_chatbot],
[instruction_txtbox, state_chatbot]).then(
chat_stream,
[state_chatbot, llama_chatbot, cancel_btn, reset_btn],
[state_chatbot, chatbot, cancel_btn, reset_btn])
[instruction_txtbox, state_chatbot]).then(chat_stream, [
state_chatbot, llama_chatbot, cancel_btn, reset_btn,
state_session_id
], [state_chatbot, chatbot, cancel_btn, reset_btn])
cancel_btn.click(cancel_func,
[state_chatbot, llama_chatbot, cancel_btn, reset_btn],
......@@ -125,6 +126,14 @@ def run_triton_server(triton_server_addr: str,
[llama_chatbot, state_chatbot, chatbot, instruction_txtbox],
cancels=[send_event])
def init():
    """Return a new unique session id for a freshly loaded browser session.

    Registered with ``demo.load`` so each page load populates
    ``state_session_id`` with its own id.
    """
    with InterFace.lock:
        # Atomic bump-and-read under the class lock.
        InterFace.global_session_id += 1
        return InterFace.global_session_id
demo.load(init, inputs=None, outputs=[state_session_id])
print(f'server is gonna mount on: http://{server_name}:{server_port}')
demo.queue(concurrency_count=4, max_size=100, api_open=True).launch(
max_threads=10,
......
# Copyright (c) OpenMMLab. All rights reserved.
import threading
from threading import Lock
from typing import Sequence
import gradio as gr
from lmdeploy.serve.async_engine import AsyncEngine
from lmdeploy.serve.gradio.constants import CSS, THEME, disable_btn, enable_btn
from lmdeploy.serve.openai.api_server import ip2id
class InterFace:
    """Process-wide shared state for the local (in-process) gradio demo.

    Used purely as a namespace of class attributes shared by all event
    handlers; never instantiated.
    """
    # The in-process inference engine; assigned once at startup.
    async_engine: AsyncEngine = None
    # Monotonically increasing counter used to hand out unique session ids
    # (incremented by init() each time a browser session loads the Blocks UI).
    global_session_id: int = 0
    # Guards global_session_id increments across concurrent gradio events.
    lock = Lock()
async def chat_stream_local(
......@@ -18,25 +19,20 @@ async def chat_stream_local(
state_chatbot: Sequence,
cancel_btn: gr.Button,
reset_btn: gr.Button,
request: gr.Request,
session_id: int,
):
"""Chat with AI assistant.
Args:
instruction (str): user's prompt
state_chatbot (Sequence): the chatting history
cancel_btn (bool): enable the cancel button or not
reset_btn (bool): enable the reset button or not
request (gr.Request): the request from a user
cancel_btn (gr.Button): the cancel button
reset_btn (gr.Button): the reset button
session_id (int): the session id
"""
session_id = threading.current_thread().ident
if request is not None:
session_id = ip2id(request.kwargs['client']['host'])
bot_summarized_response = ''
state_chatbot = state_chatbot + [(instruction, None)]
yield (state_chatbot, state_chatbot, disable_btn, enable_btn,
f'{bot_summarized_response}'.strip())
yield (state_chatbot, state_chatbot, disable_btn, enable_btn)
async for outputs in InterFace.async_engine.generate(
instruction,
......@@ -57,27 +53,21 @@ async def chat_stream_local(
state_chatbot[-1] = (state_chatbot[-1][0],
state_chatbot[-1][1] + response
) # piece by piece
yield (state_chatbot, state_chatbot, enable_btn, disable_btn,
f'{bot_summarized_response}'.strip())
yield (state_chatbot, state_chatbot, enable_btn, disable_btn)
yield (state_chatbot, state_chatbot, disable_btn, enable_btn,
f'{bot_summarized_response}'.strip())
yield (state_chatbot, state_chatbot, disable_btn, enable_btn)
async def reset_local_func(instruction_txtbox: gr.Textbox,
state_chatbot: gr.State, request: gr.Request):
state_chatbot: Sequence, session_id: int):
"""reset the session.
Args:
instruction_txtbox (str): user's prompt
state_chatbot (Sequence): the chatting history
request (gr.Request): the request from a user
session_id (int): the session id
"""
state_chatbot = []
session_id = threading.current_thread().ident
if request is not None:
session_id = ip2id(request.kwargs['client']['host'])
# end the session
async for out in InterFace.async_engine.generate('',
session_id,
......@@ -86,29 +76,21 @@ async def reset_local_func(instruction_txtbox: gr.Textbox,
sequence_start=False,
sequence_end=True):
pass
return (
state_chatbot,
state_chatbot,
gr.Textbox.update(value=''),
)
return (state_chatbot, state_chatbot, gr.Textbox.update(value=''))
async def cancel_local_func(state_chatbot: gr.State, cancel_btn: gr.Button,
reset_btn: gr.Button, request: gr.Request):
async def cancel_local_func(state_chatbot: Sequence, cancel_btn: gr.Button,
reset_btn: gr.Button, session_id: int):
"""stop the session.
Args:
instruction_txtbox (str): user's prompt
state_chatbot (Sequence): the chatting history
cancel_btn (bool): enable the cancel button or not
reset_btn (bool): enable the reset button or not
request (gr.Request): the request from a user
cancel_btn (gr.Button): the cancel button
reset_btn (gr.Button): the reset button
session_id (int): the session id
"""
yield (state_chatbot, disable_btn, disable_btn)
session_id = threading.current_thread().ident
if request is not None:
session_id = ip2id(request.kwargs['client']['host'])
# end the session
yield (state_chatbot, disable_btn, enable_btn)
async for out in InterFace.async_engine.generate('',
session_id,
request_output_len=0,
......@@ -152,6 +134,7 @@ def run_local(model_path: str,
with gr.Blocks(css=CSS, theme=THEME) as demo:
state_chatbot = gr.State([])
state_session_id = gr.State(0)
with gr.Column(elem_id='container'):
gr.Markdown('## LMDeploy Playground')
......@@ -166,24 +149,34 @@ def run_local(model_path: str,
cancel_btn = gr.Button(value='Cancel', interactive=False)
reset_btn = gr.Button(value='Reset')
send_event = instruction_txtbox.submit(
chat_stream_local,
[instruction_txtbox, state_chatbot, cancel_btn, reset_btn],
[state_chatbot, chatbot, cancel_btn, reset_btn])
send_event = instruction_txtbox.submit(chat_stream_local, [
instruction_txtbox, state_chatbot, cancel_btn, reset_btn,
state_session_id
], [state_chatbot, chatbot, cancel_btn, reset_btn])
instruction_txtbox.submit(
lambda: gr.Textbox.update(value=''),
[],
[instruction_txtbox],
)
cancel_btn.click(cancel_local_func,
[state_chatbot, cancel_btn, reset_btn],
[state_chatbot, cancel_btn, reset_btn],
cancels=[send_event])
reset_btn.click(reset_local_func, [instruction_txtbox, state_chatbot],
cancel_btn.click(
cancel_local_func,
[state_chatbot, cancel_btn, reset_btn, state_session_id],
[state_chatbot, cancel_btn, reset_btn],
cancels=[send_event])
reset_btn.click(reset_local_func,
[instruction_txtbox, state_chatbot, state_session_id],
[state_chatbot, chatbot, instruction_txtbox],
cancels=[send_event])
def init():
    """Allocate the next unique session id for a newly loaded UI session.

    Wired to ``demo.load`` so every browser session receives its own id
    in ``state_session_id``.
    """
    # Serialize the counter bump so concurrent page loads never collide.
    with InterFace.lock:
        InterFace.global_session_id += 1
        return InterFace.global_session_id
demo.load(init, inputs=None, outputs=[state_session_id])
print(f'server is gonna mount on: http://{server_name}:{server_port}')
demo.queue(concurrency_count=batch_size, max_size=100,
api_open=True).launch(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment