# Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path as osp
import random

import fire

from lmdeploy import turbomind as tm
from lmdeploy.model import MODELS
from lmdeploy.turbomind.tokenizer import Tokenizer

os.environ['TM_LOG_LEVEL'] = 'ERROR'


def input_prompt():
    """Input a prompt in the consolo interface."""
    print('\ndouble enter to end input >>> ', end='')
    sentinel = ''  # ends when this string is seen
    return '\n'.join(iter(input, sentinel))


def valid_str(string, coding='utf-8'):
    """decode text according to its encoding type."""
    invalid_chars = [b'\xef\xbf\xbd']
    bstr = bytes(string, coding)
    for invalid_char in invalid_chars:
        bstr = bstr.replace(invalid_char, b'')
    ret = bstr.decode(encoding=coding, errors='ignore')
    return ret


def main(model_name,
         model_path,
         session_id: int = 1,
         repetition_penalty: float = 1.0):
    """An example to perform model inference through the command line
    interface.

    Args:
        model_name (str): the name of the deployed model
        model_path (str): the path of the deployed model
        session_id (int): the identical id of a session
    """
    model = MODELS.get(model_name)()
    tokenizer_model_path = osp.join(model_path, 'triton_models', 'tokenizer')
    tokenizer = Tokenizer(tokenizer_model_path)
    tm_model = tm.TurboMind(model_path,
                            eos_id=tokenizer.eos_token_id,
                            stop_words=model.stop_words)
    generator = tm_model.create_instance()

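    # Per-session state: the dialogue round counter, the token offset into
    # the session's cached context, and a sampling seed kept for the session.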
    nth_round = 1
    step = 0
    seed = random.getrandbits(64)

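    # Simple REPL: 'exit' quits the program, 'end' closes the current
    # session, and any other input is sent to the model.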
    while True:
        prompt = input_prompt()
        if prompt == 'exit':
            exit(0)
        elif prompt == 'end':
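            # Run one dummy step with sequence_end=True so the engine can
            # finalize and release the session, then reset the local state.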
            prompt = model.get_prompt('', nth_round == 1)
            input_ids = tokenizer.encode(prompt)
            for outputs in generator.stream_infer(session_id=session_id,
                                                  input_ids=[input_ids],
                                                  request_output_len=512,
                                                  sequence_start=False,
                                                  sequence_end=True):
                pass
            nth_round = 1
            step = 0
            seed = random.getrandbits(64)
        else:
            print(f'session {session_id}')
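            # Refuse to continue once the cached context reaches the model's
            # session window; the user must 'end' the session to go on.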
            if step >= tm_model.session_len:
                print('WARNING: session max length exceeded.'
                      ' Please end the session.')
                continue
            prompt = model.get_prompt(prompt, nth_round == 1)
            input_ids = tokenizer.encode(prompt)
            print(f'{prompt} ', end='', flush=True)
            response_size = 0
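            # Stream the reply as it is generated; each yielded result
            # carries all ids produced so far, so only the new tail of the
            # decoded text gets printed.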
            for outputs in generator.stream_infer(
                    session_id=session_id,
                    input_ids=[input_ids],
                    stream_output=True,
                    request_output_len=512,
                    sequence_start=(nth_round == 1),
                    sequence_end=False,
                    step=step,
                    stop=False,
                    top_k=40,
                    top_p=0.8,
                    temperature=0.8,
                    repetition_penalty=repetition_penalty,
                    ignore_eos=False,
                    random_seed=seed if nth_round == 1 else None):
                res, tokens = outputs[0]
                # decode all ids so far; slice off the part already printed
                response = tokenizer.decode(res)[response_size:]
                response = valid_str(response)
                print(f'{response}', end='', flush=True)
                response_size += len(response)

            # advance the session offset by the tokens consumed this round
            step += len(input_ids) + tokens
            print()

            nth_round += 1


if __name__ == '__main__':
    fire.Fire(main)