# Copyright (c) OpenMMLab. All rights reserved.
import dataclasses
import os
import os.path as osp
import random

import fire

from lmdeploy import turbomind as tm
from lmdeploy.model import MODELS
from lmdeploy.turbomind.tokenizer import Tokenizer

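# Keep TurboMind quiet: only log messages at ERROR level or above.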
os.environ['TM_LOG_LEVEL'] = 'ERROR'


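# Generation/sampling parameters forwarded to the TurboMind generator.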
@dataclasses.dataclass
class GenParam:
    top_p: float
    top_k: float
    temperature: float
    repetition_penalty: float
    sequence_start: bool = False
    sequence_end: bool = False
    step: int = 0
    request_output_len: int = 512


def input_prompt(model_name):
    """Input a prompt in the consolo interface."""
    if model_name == 'codellama':
        print('\nenter !! to end the input >>>\n', end='')
        sentinel = '!!'
    else:
        print('\ndouble enter to end input >>> ', end='')
        sentinel = ''  # ends when this string is seen
    return '\n'.join(iter(input, sentinel))


def valid_str(string, coding='utf-8'):
    """decode text according to its encoding type."""
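    # b'\xef\xbf\xbd' is the UTF-8 encoding of U+FFFD, the replacement
    # character produced for undecodable byte sequences.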
    invalid_chars = [b'\xef\xbf\xbd']
    bstr = bytes(string, coding)
    for invalid_char in invalid_chars:
        bstr = bstr.replace(invalid_char, b'')
    ret = bstr.decode(encoding=coding, errors='ignore')
    return ret


def get_gen_param(cap,
                  sampling_param,
                  nth_round,
                  step,
                  request_output_len=512,
                  **kwargs):
    """return parameters used by token generation."""
    gen_param = GenParam(**dataclasses.asdict(sampling_param),
                         request_output_len=request_output_len)
    # Fix me later. turbomind.py doesn't support None top_k
    if gen_param.top_k is None:
        gen_param.top_k = 40

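    # A chat session spans multiple rounds: open it on the first round and
    # keep it open afterwards. Other capabilities run as one-shot requests.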
    if cap == 'chat':
        gen_param.sequence_start = (nth_round == 1)
        gen_param.sequence_end = False
        gen_param.step = step
    else:
        gen_param.sequence_start = True
        gen_param.sequence_end = True
        gen_param.step = 0
    return gen_param


def main(model_path,
         session_id: int = 1,
         cap: str = 'chat',
         sys_instruct: str = None,
         tp=1,
         stream_output=True,
         **kwargs):
    """An example to perform model inference through the command line
    interface.

    Args:
        model_path (str): the path of the deployed model
        session_id (int): the identical id of a session
        cap (str): the capability of the model. For example, codellama's
            capabilities are among ['completion', 'infilling', 'chat', 'python']
        sys_instruct (str): the content of the 'system' role, which is used by
            conversational models
        tp (int): GPU number used in tensor parallelism
        stream_output (bool): indicator for streaming output or not
        **kwargs (dict): other arguments for initializing the model's chat template
    """
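    # The deployed workspace keeps its tokenizer under triton_models/tokenizer.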
    tokenizer_model_path = osp.join(model_path, 'triton_models', 'tokenizer')
    tokenizer = Tokenizer(tokenizer_model_path)
    tm_model = tm.TurboMind(model_path, eos_id=tokenizer.eos_token_id, tp=tp)
    generator = tm_model.create_instance()

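    # Per-session state: the round counter, the number of tokens consumed
    # so far (step), and a random seed fixed for the whole session.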
    nth_round = 1
    step = 0
    seed = random.getrandbits(64)
    model_name = tm_model.model_name
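    # Build the model's chat template, overriding the system instruction
    # when one is provided.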
    model = MODELS.get(model_name)(capability=cap, **kwargs) \
        if sys_instruct is None else MODELS.get(model_name)(
            capability=cap, system=sys_instruct, **kwargs)

    print(f'session {session_id}')
    while True:
        prompt = input_prompt(model_name)
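        # Two special commands: 'exit' quits the program; 'end' closes the
        # current session so a fresh one can start.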
        if prompt == 'exit':
            exit(0)
        elif prompt == 'end':
            prompt = model.get_prompt('', nth_round == 1)
            input_ids = tokenizer.encode(prompt)
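            # Close the session server-side: send an empty prompt with
            # sequence_end=True and drain the generator.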
            for outputs in generator.stream_infer(session_id=session_id,
                                                  input_ids=[input_ids],
                                                  request_output_len=512,
                                                  sequence_start=False,
                                                  sequence_end=True,
                                                  stream_output=stream_output):
                pass
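            # Reset the per-session state for the next conversation.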
            nth_round = 1
            step = 0
            seed = random.getrandbits(64)
        else:
            prompt = model.get_prompt(prompt, nth_round == 1)
            input_ids = tokenizer.encode(prompt)
            if step + len(input_ids) >= tm_model.session_len:
                print('WARNING: the session exceeds its max length.'
                      ' Please end the session.')
                continue

            gen_param = get_gen_param(cap, model.sampling_param, nth_round,
                                      step, **kwargs)

            print(f'{prompt} ', end='', flush=True)
            response_size = 0
            for outputs in generator.stream_infer(
                    session_id=session_id,
                    input_ids=[input_ids],
                    stream_output=stream_output,
                    **dataclasses.asdict(gen_param),
                    ignore_eos=False,
                    random_seed=seed if nth_round == 1 else None):
                res, tokens = outputs[0]
                # decode res
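                # Decode only the tokens produced since the last iteration.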
                response = tokenizer.decode(res.tolist(), offset=response_size)
                response = valid_str(response)
                print(f'{response}', end='', flush=True)
                response_size = tokens

            # update step by the prompt length plus the generated tokens
            step += len(input_ids) + tokens
            print()

            nth_round += 1


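# Usage sketch (fire exposes `main`'s arguments as CLI flags; the model
# path below is a hypothetical deployed-model workspace):
#   python chat.py ./workspace --session_id 1 --cap chat --tp 1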
if __name__ == '__main__':
    fire.Fire(main)