app.py 4.33 KB
Newer Older
AllentDan's avatar
AllentDan committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
# flake8: noqa
from functools import partial
import threading

import fire
import gradio as gr
import os
from strings import ABSTRACT, TITLE
from styles import PARENT_BLOCK_CSS

from llmdeploy.serve.fastertransformer.chatbot import Chatbot


def chat_stream(instruction,
                state_chatbot,
                llama_chatbot,
                model_name: str = None):
    bot_summarized_response = ''
    model_type = 'fastertransformer'
    state_chatbot = state_chatbot + [(instruction, None)]
    session_id = threading.current_thread().ident
    bot_response = llama_chatbot.stream_infer(
        session_id, instruction, f'{session_id}-{len(state_chatbot)}')

    yield (state_chatbot, state_chatbot, f'{bot_summarized_response}'.strip())

    for status, tokens, _ in bot_response:
        if state_chatbot[-1][-1] is None or model_type != 'fairscale':
            state_chatbot[-1] = (state_chatbot[-1][0], tokens)
        else:
            state_chatbot[-1] = (state_chatbot[-1][0],
                                 state_chatbot[-1][1] + tokens
                                 )  # piece by piece
        yield (state_chatbot, state_chatbot,
               f'{bot_summarized_response}'.strip())

    yield (state_chatbot, state_chatbot, f'{bot_summarized_response}'.strip())


def reset_textbox():
    """Return a gradio update that clears the instruction textbox."""
    cleared = gr.Textbox.update(value='')
    return cleared


def reset_everything_func(instruction_txtbox, state_chatbot, llama_chatbot,
                          triton_server_addr, model_name):
    """Drop the chat history and start a brand-new chatbot session.

    Args:
        instruction_txtbox: the textbox component (unused, supplied by gradio).
        state_chatbot: current history (discarded).
        llama_chatbot: current chatbot client (replaced).
        triton_server_addr: address of the triton inference server.
        model_name: name of the deployed model.

    Returns:
        (new chatbot, empty history for gr.State, empty history for the
        Chatbot widget, update clearing the textbox).
    """
    fresh_history = []
    level = os.environ.get('SERVICE_LOG_LEVEL', 'INFO')
    fresh_chatbot = Chatbot(triton_server_addr,
                            model_name,
                            log_level=level,
                            display=True)
    return (fresh_chatbot,
            fresh_history,
            fresh_history,
            gr.Textbox.update(value=''))


def cancel_func(instruction_txtbox, state_chatbot, llama_chatbot):
    session_id = llama_chatbot._session.session_id
    llama_chatbot.cancel(session_id)

    return (
        llama_chatbot,
        state_chatbot,
    )


def run(triton_server_addr: str,
        model_name: str,
        server_name: str = 'localhost',
        server_port: int = 6006):
    """Assemble the gradio chat UI and serve it (blocking).

    Args:
        triton_server_addr: address of the triton inference server.
        model_name: name of the deployed model; also used as the chat label.
        server_name: host interface gradio binds to.
        server_port: port gradio listens on.
    """
    with gr.Blocks(css=PARENT_BLOCK_CSS, theme='ParityError/Anime') as demo:
        # Pre-bind deployment-specific arguments so the gradio callbacks
        # only receive UI components at event time.
        send_fn = partial(chat_stream, model_name=model_name)
        reset_fn = partial(
            reset_everything_func,
            model_name=model_name,
            triton_server_addr=triton_server_addr)
        log_level = os.environ.get('SERVICE_LOG_LEVEL', 'INFO')
        llama_chatbot = gr.State(
            Chatbot(
                triton_server_addr,
                model_name,
                log_level=log_level,
                display=True))
        state_chatbot = gr.State([])

        with gr.Column(elem_id='col_container'):
            gr.Markdown(f'## {TITLE}\n\n\n{ABSTRACT}')

            chatbot = gr.Chatbot(elem_id='chatbot', label=model_name)
            instruction_txtbox = gr.Textbox(
                placeholder='What do you want to say to AI?',
                label='Instruction')
            with gr.Row():
                cancel_btn = gr.Button(value='Cancel')
                reset_btn = gr.Button(value='Reset')

        # Submitting the textbox streams tokens into the chat window...
        send_event = instruction_txtbox.submit(
            send_fn,
            [instruction_txtbox, state_chatbot, llama_chatbot],
            [state_chatbot, chatbot],
            batch=False,
            max_batch_size=1,
        )
        # ...and immediately clears the textbox via a second listener.
        instruction_txtbox.submit(
            reset_textbox,
            [],
            [instruction_txtbox],
        )

        # Cancel aborts the streaming event and the backend generation.
        cancel_btn.click(
            cancel_func,
            [instruction_txtbox, state_chatbot, llama_chatbot],
            [llama_chatbot, chatbot],
            cancels=[send_event])

        # Reset rebuilds the chatbot session and wipes history and textbox.
        reset_btn.click(
            reset_fn,
            [instruction_txtbox, state_chatbot, llama_chatbot],
            [llama_chatbot, state_chatbot, chatbot, instruction_txtbox],
            cancels=[send_event])

    demo.queue(
        concurrency_count=4, max_size=100, api_open=True).launch(
            max_threads=10,
            share=True,
            server_port=server_port,
            server_name=server_name,
        )


if __name__ == '__main__':
    # Expose `run` as a CLI: positional/flag arguments map to its parameters
    # (e.g. `python app.py <triton_server_addr> <model_name> --server_port 6006`).
    fire.Fire(run)