app.py 4.5 KB
Newer Older
AllentDan's avatar
AllentDan committed
1
# Copyright (c) OpenMMLab. All rights reserved.
AllentDan's avatar
AllentDan committed
2
3
from functools import partial
import threading
AllentDan's avatar
AllentDan committed
4
from typing import Sequence
AllentDan's avatar
AllentDan committed
5
6
7
8
9
10
11

import fire
import gradio as gr
import os

from llmdeploy.serve.fastertransformer.chatbot import Chatbot

# Stylesheet passed to gr.Blocks(css=...) in run(): centers the page content
# in a 95%-wide container and gives the '#chatbot' panel a fixed height that
# scrolls when the conversation overflows.
CSS = """
#container {
    width: 95%;
    margin-left: auto;
    margin-right: auto;
}

#chatbot {
    height: 500px;
    overflow: auto;
}

.chat_wrap_space {
    margin-left: 0.5em;
}
"""

# Gradio "Soft" theme for the playground: blue primary hue, sky secondary
# hue, and the Inconsolata Google font followed by Arial and sans-serif in
# the font stack.
THEME = gr.themes.Soft(
    primary_hue=gr.themes.colors.blue,
    secondary_hue=gr.themes.colors.sky,
    font=[
        gr.themes.GoogleFont('Inconsolata'),
        'Arial',
        'sans-serif',
    ],
)


def chat_stream(instruction: str,
                state_chatbot: Sequence,
                llama_chatbot: 'Chatbot',
                model_name: str = None):
    """Stream the model's reply to ``instruction`` into the chat history.

    Args:
        instruction: The user's new message; sent to the model as the prompt.
        state_chatbot: Existing chat history as ``(user, bot)`` pairs. It is
            not modified; an extended copy is yielded instead.
        llama_chatbot: Chat client. Its ``stream_infer`` must yield 3-tuples
            whose second item replaces the displayed reply on each step.
        model_name: Not used by the body; ``run`` binds it through
            ``functools.partial``.

    Yields:
        ``(history, history, summary)`` once after the request is issued,
        once per streamed item, and once more after the stream ends. The
        same ``history`` list object is mutated in place between yields.
        ``summary`` is always the empty string.
    """
    # Never filled in; kept so the yielded 3-tuple shape stays unchanged.
    # NOTE(review): run() wires only two outputs for this handler, so the
    # third element looks unused - confirm before dropping it.
    summary = ''
    # list() lets any Sequence (e.g. a tuple) be extended; the caller's
    # history object is left untouched.
    history = list(state_chatbot) + [(instruction, None)]
    # NOTE(review): the session is keyed on the worker thread id, so it
    # follows gradio's thread pool rather than the browser session -
    # confirm this is intended.
    session_id = threading.current_thread().ident
    responses = llama_chatbot.stream_infer(
        session_id, instruction, f'{session_id}-{len(history)}')

    yield history, history, summary

    for _status, text, _ in responses:
        # Each streamed item replaces the bot's half of the newest turn.
        history[-1] = (instruction, text)
        yield history, history, summary

    yield history, history, summary


def reset_all_func(instruction_txtbox: gr.Textbox, state_chatbot: gr.State,
                   llama_chatbot: gr.State, triton_server_addr: str,
                   model_name: str):
    """Start over with a fresh chat client and an empty history.

    Args:
        instruction_txtbox: Current textbox value; ignored.
        state_chatbot: Current chat history; ignored (replaced by ``[]``).
        llama_chatbot: Current chat client; ignored (replaced by a new one).
            NOTE(review): the old client is simply dropped - if it holds
            server-side session state, consider ending it here.
        triton_server_addr: Server address forwarded to the new ``Chatbot``.
        model_name: Model name forwarded to the new ``Chatbot``.

    Returns:
        ``(new_chatbot, history, history, textbox_update)`` where ``history``
        is one shared empty list and ``textbox_update`` clears the textbox.
        This matches the four outputs ``run`` wires to the Reset button.
    """
    log_level = os.environ.get('SERVICE_LOG_LEVEL', 'INFO')
    fresh_chatbot = Chatbot(
        triton_server_addr, model_name, log_level=log_level, display=True)
    # The same list object backs both the history state and the display.
    empty_history = []
    return (fresh_chatbot, empty_history, empty_history,
            gr.Textbox.update(value=''))


def cancel_func(
    instruction_txtbox: 'gr.Textbox',
    state_chatbot: 'gr.State',
    llama_chatbot: 'gr.State',
):
    """Cancel the chat client's active session, if it has one.

    Args:
        instruction_txtbox: Current textbox value; unused.
        state_chatbot: Chat history; returned unchanged.
        llama_chatbot: Chat client whose ``_session`` is cancelled.

    Returns:
        ``(llama_chatbot, state_chatbot)``, both unchanged.
    """
    # ``_session`` is a private Chatbot attribute. Assumption (confirm
    # against Chatbot): it is missing or None before the first request, so
    # pressing Cancel before chatting must not raise AttributeError.
    session = getattr(llama_chatbot, '_session', None)
    if session is not None:
        llama_chatbot.cancel(session.session_id)
    return llama_chatbot, state_chatbot


def run(triton_server_addr: str,
        model_name: str,
        server_name: str = 'localhost',
        server_port: int = 6006,
        share: bool = True):
    """Build and launch the LLMDeploy chat playground.

    Args:
        triton_server_addr: Server address forwarded to every ``Chatbot``
            this UI creates.
        model_name: Forwarded to ``Chatbot`` and used as the chat panel label.
        server_name: Host name/interface passed to gradio's ``launch``.
        server_port: Port passed to gradio's ``launch``.
        share: Passed to gradio's ``launch(share=...)``, which controls the
            public share link. Defaults to ``True`` to keep the previous
            hard-coded behaviour; pass ``--share=False`` to stay local.
    """
    with gr.Blocks(css=CSS, theme=THEME) as demo:
        chat_interface = partial(chat_stream, model_name=model_name)
        reset_all = partial(
            reset_all_func,
            model_name=model_name,
            triton_server_addr=triton_server_addr)
        log_level = os.environ.get('SERVICE_LOG_LEVEL', 'INFO')
        # NOTE(review): gradio documents State as deep-copying its initial
        # value per user session - confirm Chatbot copies cleanly.
        llama_chatbot = gr.State(
            Chatbot(
                triton_server_addr,
                model_name,
                log_level=log_level,
                display=True))
        state_chatbot = gr.State([])

        with gr.Column(elem_id='container'):
            gr.Markdown('## LLMDeploy Playground')

            chatbot = gr.Chatbot(elem_id='chatbot', label=model_name)
            instruction_txtbox = gr.Textbox(
                placeholder='Please input the instruction',
                label='Instruction')
            with gr.Row():
                cancel_btn = gr.Button(value='Cancel')
                reset_btn = gr.Button(value='Reset')

        send_event = instruction_txtbox.submit(
            chat_interface,
            [instruction_txtbox, state_chatbot, llama_chatbot],
            [state_chatbot, chatbot],
            batch=False,
            max_batch_size=1,
        )
        # A second listener on the same submit event clears the textbox.
        instruction_txtbox.submit(
            lambda: gr.Textbox.update(value=''),
            [],
            [instruction_txtbox],
        )

        cancel_btn.click(
            cancel_func, [instruction_txtbox, state_chatbot, llama_chatbot],
            [llama_chatbot, chatbot],
            cancels=[send_event])

        reset_btn.click(
            reset_all, [instruction_txtbox, state_chatbot, llama_chatbot],
            [llama_chatbot, state_chatbot, chatbot, instruction_txtbox],
            cancels=[send_event])

    demo.queue(
        concurrency_count=4, max_size=100, api_open=True).launch(
            max_threads=10,
            share=share,
            server_port=server_port,
            server_name=server_name,
        )


# Script entry point: expose run() as a command-line interface via fire.
if __name__ == '__main__':
    fire.Fire(component=run)