# Copyright (c) OpenMMLab. All rights reserved.
import os
import threading
from functools import partial
from typing import Sequence

import fire
import gradio as gr

from lmdeploy.serve.turbomind.chatbot import Chatbot

CSS = """
#container {
    width: 95%;
    margin-left: auto;
    margin-right: auto;
}

#chatbot {
    height: 500px;
    overflow: auto;
}

.chat_wrap_space {
    margin-left: 0.5em
}
"""

THEME = gr.themes.Soft(
    primary_hue=gr.themes.colors.blue,
    secondary_hue=gr.themes.colors.sky,
    font=[gr.themes.GoogleFont('Inconsolata'), 'Arial', 'sans-serif'])


def chat_stream(instruction: str,
                state_chatbot: Sequence,
                llama_chatbot: Chatbot,
                model_name: str = None):
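    """Stream the model's reply to the latest instruction.

    Yields the updated chat history after each chunk of generated tokens.
    """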
    bot_summarized_response = ''
    model_type = 'turbomind'
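    # append the new turn; the worker thread id doubles as the session id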
    state_chatbot = state_chatbot + [(instruction, None)]
    session_id = threading.current_thread().ident
    bot_response = llama_chatbot.stream_infer(
        session_id, instruction, f'{session_id}-{len(state_chatbot)}')

    yield (state_chatbot, state_chatbot, f'{bot_summarized_response}'.strip())

    for status, tokens, _ in bot_response:
        if state_chatbot[-1][-1] is None or model_type != 'fairscale':
            state_chatbot[-1] = (state_chatbot[-1][0], tokens)
        else:
            state_chatbot[-1] = (state_chatbot[-1][0],
                                 state_chatbot[-1][1] + tokens
                                 )  # piece by piece
        yield (state_chatbot, state_chatbot,
               f'{bot_summarized_response}'.strip())

    yield (state_chatbot, state_chatbot, f'{bot_summarized_response}'.strip())


def reset_all_func(instruction_txtbox: gr.Textbox, state_chatbot: gr.State,
                   llama_chatbot: gr.State, triton_server_addr: str,
                   model_name: str):
    """Reset the session: clear the history and rebuild the Chatbot."""
    state_chatbot = []
    log_level = os.environ.get('SERVICE_LOG_LEVEL', 'INFO')
    llama_chatbot = Chatbot(triton_server_addr,
                            model_name,
                            log_level=log_level,
                            display=True)

    return (
        llama_chatbot,
        state_chatbot,
        state_chatbot,
        gr.Textbox.update(value=''),
    )


def cancel_func(
    instruction_txtbox: gr.Textbox,
    state_chatbot: gr.State,
    llama_chatbot: gr.State,
):
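    """Stop the in-flight generation for the current chatbot session."""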
    session_id = llama_chatbot._session.session_id
    llama_chatbot.cancel(session_id)

    return (
        llama_chatbot,
        state_chatbot,
    )


def run(triton_server_addr: str,
        model_name: str,
        server_name: str = 'localhost',
        server_port: int = 6006):
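    """Launch a Gradio playground backed by a Triton inference server.

    Args:
        triton_server_addr (str): address of the Triton inference server.
        model_name (str): name of the deployed model.
        server_name (str): host name that the Gradio app binds to.
        server_port (int): port that the Gradio app listens on.
    """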
    with gr.Blocks(css=CSS, theme=THEME) as demo:
        chat_interface = partial(chat_stream, model_name=model_name)
        reset_all = partial(reset_all_func,
                            model_name=model_name,
                            triton_server_addr=triton_server_addr)
        log_level = os.environ.get('SERVICE_LOG_LEVEL', 'INFO')
        llama_chatbot = gr.State(
            Chatbot(triton_server_addr,
                    model_name,
                    log_level=log_level,
                    display=True))
        state_chatbot = gr.State([])

        with gr.Column(elem_id='container'):
            gr.Markdown('## LMDeploy Playground')

            chatbot = gr.Chatbot(elem_id='chatbot', label=model_name)
            instruction_txtbox = gr.Textbox(
                placeholder='Please input the instruction',
                label='Instruction')
            with gr.Row():
                cancel_btn = gr.Button(value='Cancel')
                reset_btn = gr.Button(value='Reset')

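        # stream chat responses into the UI; keep the event handle so the
        # Cancel and Reset buttons can abort it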
        send_event = instruction_txtbox.submit(
            chat_interface,
            [instruction_txtbox, state_chatbot, llama_chatbot],
            [state_chatbot, chatbot],
            batch=False,
            max_batch_size=1,
        )
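        # clear the input textbox once the instruction has been submitted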
        instruction_txtbox.submit(
            lambda: gr.Textbox.update(value=''),
            [],
            [instruction_txtbox],
        )

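        # abort the streaming event and cancel the session on the server side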
        cancel_btn.click(cancel_func,
                         [instruction_txtbox, state_chatbot, llama_chatbot],
                         [llama_chatbot, chatbot],
                         cancels=[send_event])

        reset_btn.click(
            reset_all, [instruction_txtbox, state_chatbot, llama_chatbot],
            [llama_chatbot, state_chatbot, chatbot, instruction_txtbox],
            cancels=[send_event])

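    # a request queue is required for streaming outputs and event cancellation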
    demo.queue(concurrency_count=4, max_size=100, api_open=True).launch(
        max_threads=10,
        share=True,
        server_port=server_port,
        server_name=server_name,
    )


if __name__ == '__main__':
    fire.Fire(run)