# Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path as osp
import random

import fire

from lmdeploy import turbomind as tm
from lmdeploy.model import MODELS
from lmdeploy.turbomind.tokenizer import Tokenizer

os.environ['TM_LOG_LEVEL'] = 'ERROR'


def input_prompt():
    """Input a prompt in the consolo interface."""
    print('\ndouble enter to end input >>> ', end='')
    sentinel = ''  # ends when this string is seen
    return '\n'.join(iter(input, sentinel))


def valid_str(string, coding='utf-8'):
    """decode text according to its encoding type."""
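    # b'\xef\xbf\xbd' is the UTF-8 encoding of U+FFFD, the replacement
    # character produced when a multi-byte sequence is split mid-stream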
    invalid_chars = [b'\xef\xbf\xbd']
    bstr = bytes(string, coding)
    for invalid_char in invalid_chars:
        bstr = bstr.replace(invalid_char, b'')
    ret = bstr.decode(encoding=coding, errors='ignore')
    return ret


def main(model_path,
         session_id: int = 1,
         repetition_penalty: float = 1.0,
         tp: int = 1,
         stream_output: bool = True):
    """An example to perform model inference through the command line
    interface.

    Args:
        model_path (str): the path of the deployed model
        session_id (int): the unique id of a session
        repetition_penalty (float): parameter to penalize repetition
        tp (int): number of GPUs used for tensor parallelism
        stream_output (bool): whether to stream the output
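
    Example (placeholder path, assuming the script is run directly):
        python chat.py ./workspace --session_id 1 --tp 1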
    """
    tokenizer_model_path = osp.join(model_path, 'triton_models', 'tokenizer')
    tokenizer = Tokenizer(tokenizer_model_path)
    tm_model = tm.TurboMind(model_path, eos_id=tokenizer.eos_token_id, tp=tp)
    generator = tm_model.create_instance()

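    # Chat-loop state: `nth_round` is the 1-based round index, `step` tracks
    # how many tokens the session has consumed so far, and `seed` fixes the
    # sampling seed for the lifetime of the session.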
    nth_round = 1
    step = 0
    seed = random.getrandbits(64)
    model_name = tm_model.model_name
    model = MODELS.get(model_name)()

    while True:
        prompt = input_prompt()
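        # 'exit' quits the program; 'end' closes the current session so the
        # next prompt starts a fresh one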
        if prompt == 'exit':
            exit(0)
        elif prompt == 'end':
            prompt = model.get_prompt('', nth_round == 1)
            input_ids = tokenizer.encode(prompt)
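            # an empty request with sequence_end=True finalizes the session
            # on the engine side (the outputs are drained and discarded)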
            for outputs in generator.stream_infer(session_id=session_id,
                                                  input_ids=[input_ids],
                                                  request_output_len=512,
                                                  sequence_start=False,
                                                  sequence_end=True,
                                                  stream_output=stream_output):
                pass
            nth_round = 1
            step = 0
            seed = random.getrandbits(64)
        else:
            print(f'session {session_id}')
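            # refuse new input once the session history would exceed the
            # model's context window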
            if step >= tm_model.session_len:
                print('WARNING: exceeded session max length.'
                      ' Please end the session.')
                continue
            prompt = model.get_prompt(prompt, nth_round == 1)
            input_ids = tokenizer.encode(prompt)
            print(f'{prompt} ', end='', flush=True)
            response_size = 0
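            # stream the reply; sequence_start is only True on the first
            # round, so later rounds resume from the cached history at `step`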
            for outputs in generator.stream_infer(
                    session_id=session_id,
                    input_ids=[input_ids],
                    stream_output=stream_output,
                    request_output_len=512,
                    sequence_start=(nth_round == 1),
                    sequence_end=False,
                    step=step,
                    stop=False,
                    top_k=40,
                    top_p=0.8,
                    temperature=0.8,
                    repetition_penalty=repetition_penalty,
                    ignore_eos=False,
                    random_seed=seed if nth_round == 1 else None):
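                # outputs[0] carries the ids generated so far for this
                # session and their running token count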
                res, tokens = outputs[0]
                # decode everything and keep only the not-yet-printed suffix
                response = tokenizer.decode(res)[response_size:]
                response = valid_str(response)
                print(f'{response}', end='', flush=True)
                response_size += len(response)

            # advance the session offset by the prompt and generated tokens
            step += len(input_ids) + tokens
            print()

            nth_round += 1


if __name__ == '__main__':
    fire.Fire(main)