# Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path as osp
import random

import fire

from lmdeploy import turbomind as tm
from lmdeploy.model import MODELS
from lmdeploy.turbomind.tokenizer import Tokenizer

os.environ['TM_LOG_LEVEL'] = 'ERROR'


def input_prompt():
    """Read a multi-line prompt from the console.

    Lines are collected from stdin until an empty line (a double
    enter) is entered; the collected lines are joined with newlines.
    """
    print('\ndouble enter to end input >>> ', end='')
    lines = []
    while True:
        line = input()
        if line == '':  # empty line is the end-of-input sentinel
            break
        lines.append(line)
    return '\n'.join(lines)


q.yao's avatar
q.yao committed
22
def valid_str(string, coding='utf-8'):
    """Remove invalid replacement characters from ``string``.

    Encodes the text with ``coding``, strips the UTF-8 encoding of the
    Unicode replacement character (U+FFFD), then decodes it back while
    ignoring any remaining undecodable bytes.
    """
    encoded = bytes(string, coding)
    # b'\xef\xbf\xbd' is U+FFFD, emitted for undecodable input elsewhere
    for bad_seq in (b'\xef\xbf\xbd', ):
        encoded = encoded.replace(bad_seq, b'')
    return encoded.decode(encoding=coding, errors='ignore')


32
def main(model_name, model_path, session_id: int = 1):
    """An example to perform model inference through the command line
    interface.

    Args:
        model_name (str): the name of the deployed model
        model_path (str): the path of the deployed model
        session_id (int): the identical id of a session
    """
    model = MODELS.get(model_name)()
    tokenizer_model_path = osp.join(model_path, 'triton_models', 'tokenizer')
    tokenizer = Tokenizer(tokenizer_model_path)
    tm_model = tm.TurboMind(model_path,
                            eos_id=tokenizer.eos_token_id,
                            stop_words=model.stop_words)
    generator = tm_model.create_instance()

    # Per-session dialogue state: round counter, accumulated token count
    # (history length fed back as `step`), and a per-session sampling seed.
    nth_round = 1
    step = 0
    seed = random.getrandbits(64)

    while True:
        prompt = input_prompt()
        if prompt == 'exit':
            exit(0)
        elif prompt == 'end':
            # Terminate the sequence on the server side (sequence_end=True),
            # then reset the local session state for a fresh dialogue.
            prompt = model.get_prompt('', nth_round == 1)
            input_ids = tokenizer.encode(prompt)
            for outputs in generator.stream_infer(session_id=session_id,
                                                  input_ids=[input_ids],
                                                  request_output_len=512,
                                                  sequence_start=False,
                                                  sequence_end=True):
                pass
            nth_round = 1
            step = 0
            seed = random.getrandbits(64)
        else:
            print(f'session {session_id}')
            if step >= tm_model.session_len:
                print('WARNING: exceed session max length.'
                      ' Please end the session.')
                continue
            prompt = model.get_prompt(prompt, nth_round == 1)
            input_ids = tokenizer.encode(prompt)
            print(f'{prompt} ', end='', flush=True)
            response_size = 0
            # Guard against a stream that yields nothing: `tokens` is read
            # after the loop, so give it a safe default to avoid NameError.
            tokens = 0
            for outputs in generator.stream_infer(
                    session_id=session_id,
                    input_ids=[input_ids],
                    stream_output=True,
                    request_output_len=512,
                    sequence_start=(nth_round == 1),
                    sequence_end=False,
                    step=step,
                    stop=False,
                    top_k=40,
                    top_p=0.8,
                    temperature=0.8,
                    repetition_penalty=1.05,
                    ignore_eos=False,
                    random_seed=seed if nth_round == 1 else None):
                # Each yield carries the cumulative decoded ids and the
                # generated-token count; print only the newly decoded tail.
                res, tokens = outputs[0]
                # decode res
                response = tokenizer.decode(res)[response_size:]
                response = valid_str(response)
                print(f'{response}', end='', flush=True)
                response_size += len(response)

            # update step: history grows by the prompt plus generated tokens
            step += len(input_ids) + tokens
            print()

            nth_round += 1


if __name__ == '__main__':
    # Expose `main` as a command-line interface via python-fire.
    fire.Fire(main)