# Copyright (c) OpenMMLab. All rights reserved.
import dataclasses
import os
import os.path as osp
import random

import fire

from lmdeploy import turbomind as tm
from lmdeploy.model import MODELS
from lmdeploy.tokenizer import Tokenizer

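# only surface error-level logs from TurboMind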
os.environ['TM_LOG_LEVEL'] = 'ERROR'


@dataclasses.dataclass
class GenParam:
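    """Sampling and session parameters passed to the generator."""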
    top_p: float
    top_k: float
    temperature: float
    repetition_penalty: float
    sequence_start: bool = False
    sequence_end: bool = False
    step: int = 0
    request_output_len: int = 512


def input_prompt(model_name):
    """Input a prompt in the consolo interface."""
    if model_name == 'codellama':
        print('\nenter !! to end the input >>>\n', end='')
        sentinel = '!!'
    else:
        print('\ndouble enter to end input >>> ', end='')
        sentinel = ''  # ends when this string is seen
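    # iter(input, sentinel) yields lines from input() until a line equal
    # to the sentinel is entered, so prompts may span multiple lines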
    return '\n'.join(iter(input, sentinel))


def valid_str(string, coding='utf-8'):
    """decode text according to its encoding type."""
    invalid_chars = [b'\xef\xbf\xbd']
    bstr = bytes(string, coding)
    for invalid_char in invalid_chars:
        bstr = bstr.replace(invalid_char, b'')
    ret = bstr.decode(encoding=coding, errors='ignore')
    return ret


def get_gen_param(cap,
                  sampling_param,
                  nth_round,
                  step,
                  request_output_len=512,
                  **kwargs):
    """return parameters used by token generation."""
    gen_param = GenParam(**dataclasses.asdict(sampling_param),
                         request_output_len=request_output_len)
    # Fix me later. turbomind.py doesn't support None top_k
    if gen_param.top_k is None:
        gen_param.top_k = 40

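    # in chat mode the session state is kept across rounds; other
    # capabilities treat every request as a standalone sequence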
    if cap == 'chat':
        gen_param.sequence_start = (nth_round == 1)
        gen_param.sequence_end = False
        gen_param.step = step
    else:
        gen_param.sequence_start = True
        gen_param.sequence_end = True
        gen_param.step = 0
    return gen_param


def main(model_path,
         session_id: int = 1,
         cap: str = 'chat',
         tp=1,
         stream_output=True,
         **kwargs):
    """An example to perform model inference through the command line
    interface.

    Args:
        model_path (str): the path of the deployed model
        session_id (int): the unique id of a session
        cap (str): the capability of the model. For example, codellama
            supports ['completion', 'infilling', 'chat', 'python']
        tp (int): the number of GPUs used in tensor parallelism
        stream_output (bool): whether to stream the output or not
        **kwargs (dict): other arguments for initializing the model's
            chat template
    """
    tokenizer_model_path = osp.join(model_path, 'triton_models', 'tokenizer')
    tokenizer = Tokenizer(tokenizer_model_path)
    tm_model = tm.TurboMind(model_path, eos_id=tokenizer.eos_token_id, tp=tp)
    generator = tm_model.create_instance()

    nth_round = 1
    step = 0
    seed = random.getrandbits(64)
    model_name = tm_model.model_name
    model = MODELS.get(model_name)(capability=cap, **kwargs)

    print(f'session {session_id}')
    while True:
        prompt = input_prompt(model_name)
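        # 'exit' quits the tool; 'end' finishes the current session and
        # resets the round counter and step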
        if prompt == 'exit':
            exit(0)
        elif prompt == 'end':
            prompt = model.get_prompt('', nth_round == 1)
            input_ids = tokenizer.encode(prompt)
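            # drain the generator with sequence_end=True so the sequence
            # is finalized before a new session starts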
            for outputs in generator.stream_infer(session_id=session_id,
                                                  input_ids=[input_ids],
                                                  request_output_len=512,
                                                  sequence_start=False,
                                                  sequence_end=True,
                                                  stream_output=stream_output):
                pass
            nth_round = 1
            step = 0
            seed = random.getrandbits(64)
        else:
            prompt = model.get_prompt(prompt, nth_round == 1)
            input_ids = tokenizer.encode(prompt)
            if step + len(input_ids) >= tm_model.session_len:
                print('WARNING: exceed session max length.'
                      ' Please end the session.')
                continue

            gen_param = get_gen_param(cap, model.sampling_param, nth_round,
                                      step, **kwargs)

            print(f'{prompt} ', end='', flush=True)
            response_size = 0
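            # the random seed is only passed on the first round of a
            # session; later rounds leave it as None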
            for outputs in generator.stream_infer(
                    session_id=session_id,
                    input_ids=[input_ids],
                    stream_output=stream_output,
                    **dataclasses.asdict(gen_param),
                    ignore_eos=False,
                    random_seed=seed if nth_round == 1 else None):
                res, tokens = outputs[0]
                # decode res
                response = tokenizer.decode(res.tolist(), offset=response_size)
                # utf-8 char at the end means it's a potential unfinished
                # byte sequence, continue to concatenate it with the next
                # sequence and decode them together
                if response.endswith('�'):
                    continue
                response = valid_str(response)
                print(f'{response}', end='', flush=True)
                response_size = tokens

            # advance the session step by the prompt and generated tokens
            step += len(input_ids) + tokens
            print()

            nth_round += 1


if __name__ == '__main__':
    fire.Fire(main)
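
# A minimal usage sketch (the `./workspace` path is hypothetical; it
# should point to a turbomind model directory, e.g. one produced by
# `lmdeploy convert`):
#
#   python chat.py ./workspace --session_id 1 --cap chat --tp 1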