# Copyright (c) OpenMMLab. All rights reserved.
import dataclasses
import os
import os.path as osp
import random

from lmdeploy.model import MODELS

os.environ['TM_LOG_LEVEL'] = 'ERROR'


@dataclasses.dataclass
class GenParam:
    top_p: float
    top_k: float
    temperature: float
    repetition_penalty: float
    sequence_start: bool = False
    sequence_end: bool = False
    step: int = 0
    request_output_len: int = 512
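    # note (not in the original file): these fields are forwarded verbatim
    # to generator.stream_infer via dataclasses.asdict() in main() below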


def input_prompt(model_name):
    """Input a prompt in the consolo interface."""
    if model_name == 'codellama':
        print('\nenter !! to end the input >>>\n', end='')
        sentinel = '!!'
    else:
        print('\ndouble enter to end input >>> ', end='')
        sentinel = ''  # ends when this string is seen
    return '\n'.join(iter(input, sentinel))


def valid_str(string, coding='utf-8'):
    """decode text according to its encoding type."""
    invalid_chars = [b'\xef\xbf\xbd']
    bstr = bytes(string, coding)
    for invalid_char in invalid_chars:
        bstr = bstr.replace(invalid_char, b'')
    ret = bstr.decode(encoding=coding, errors='ignore')
    return ret


def get_gen_param(cap,
                  sampling_param,
                  nth_round,
                  step,
                  request_output_len=512,
                  **kwargs):
    """return parameters used by token generation."""
    gen_param = GenParam(**dataclasses.asdict(sampling_param),
                         request_output_len=request_output_len)
    # FIXME: turbomind.py doesn't support top_k=None, so fall back to 40
    if gen_param.top_k is None:
        gen_param.top_k = 40

    if cap == 'chat':
        gen_param.sequence_start = (nth_round == 1)
        gen_param.sequence_end = False
        gen_param.step = step
    else:
        gen_param.sequence_start = True
        gen_param.sequence_end = True
        gen_param.step = 0
    return gen_param


def main(model_path,
         session_id: int = 1,
         cap: str = 'chat',
         tp=1,
         stream_output=True,
         **kwargs):
    """An example to perform model inference through the command line
    interface.

    Args:
        model_path (str): the path of the deployed model
        session_id (int): the unique id of a session
        cap (str): the capability of the model. For example, codellama
            supports ['completion', 'infilling', 'chat', 'python']
        tp (int): the number of GPUs used for tensor parallelism
        stream_output (bool): indicator for streaming output or not
        **kwargs (dict): other arguments for initializing the model's
            chat template
    """
    from lmdeploy import turbomind as tm
    from lmdeploy.tokenizer import Tokenizer

    tokenizer_model_path = osp.join(model_path, 'triton_models', 'tokenizer')
    tokenizer = Tokenizer(tokenizer_model_path)
    tm_model = tm.TurboMind(model_path, eos_id=tokenizer.eos_token_id, tp=tp)
    generator = tm_model.create_instance()

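    # per-session state: round counter, token offset into the sequence and
    # a 64-bit sampling seed (passed to the engine on round 1 only)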
    nth_round = 1
    step = 0
    seed = random.getrandbits(64)
    model_name = tm_model.model_name
    model = MODELS.get(model_name)(capability=cap, **kwargs)

    print(f'session {session_id}')
    while True:
        prompt = input_prompt(model_name)
        if prompt == 'exit':
            exit(0)
        elif prompt == 'end':
            prompt = model.get_prompt('', nth_round == 1)
            input_ids = tokenizer.encode(prompt)
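            # end the session: stream an empty prompt with
            # sequence_end=True so the engine closes the sequence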
            for outputs in generator.stream_infer(session_id=session_id,
                                                  input_ids=[input_ids],
                                                  request_output_len=512,
                                                  sequence_start=False,
                                                  sequence_end=True,
                                                  stream_output=stream_output):
                pass
            nth_round = 1
            step = 0
            seed = random.getrandbits(64)
        else:
            prompt = model.get_prompt(prompt, nth_round == 1)
            input_ids = tokenizer.encode(prompt)
            if step + len(input_ids) >= tm_model.session_len:
                print('WARNING: session max length exceeded.'
                      ' Please end the session.')
                continue

            gen_param = get_gen_param(cap, model.sampling_param, nth_round,
                                      step, **kwargs)

            print(f'{prompt} ', end='', flush=True)
            response_size = 0
            for outputs in generator.stream_infer(
                    session_id=session_id,
                    input_ids=[input_ids],
                    stream_output=stream_output,
                    **dataclasses.asdict(gen_param),
                    ignore_eos=False,
                    random_seed=seed if nth_round == 1 else None):
                res, tokens = outputs[0]
                # incrementally decode res, skipping the already-printed
                # prefix via the token offset
                response = tokenizer.decode(res.tolist(), offset=response_size)
                # a trailing replacement char means the chunk may end with
                # an unfinished multi-byte utf-8 sequence; skip printing and
                # decode it together with the next chunk
                if response.endswith('�'):
                    continue
                response = valid_str(response)
                print(f'{response}', end='', flush=True)
                response_size = tokens

            # update step
            step += len(input_ids) + tokens
            print()

            nth_round += 1


if __name__ == '__main__':
    import fire

    fire.Fire(main)
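
# Example usage (path is illustrative; fire maps CLI args onto main()):
#   python chat.py /path/to/turbomind_workspace --cap chat --tp 1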