# Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path as osp
import random

import fire

from lmdeploy import turbomind as tm
from lmdeploy.model import MODELS
from lmdeploy.turbomind.tokenizer import Tokenizer

os.environ['TM_LOG_LEVEL'] = 'ERROR'

def input_prompt():
    """Read a multi-line prompt from the console.

    Reading stops at the first empty line, i.e. the user presses
    enter twice to finish the input.
    """
    print('\ndouble enter to end input >>> ', end='')
    lines = []
    # iter(input, '') yields lines until an empty line is entered.
    for line in iter(input, ''):
        lines.append(line)
    return '\n'.join(lines)


def valid_str(string, coding='utf-8'):
    """Decode text according to its encoding type.

    Byte sequences of the Unicode replacement character (U+FFFD) are
    stripped before decoding, and any remaining undecodable bytes are
    ignored.
    """
    encoded = bytes(string, coding)
    # Remove the UTF-8 encoding of U+FFFD produced by partial decodes.
    for bad_bytes in (b'\xef\xbf\xbd',):
        encoded = encoded.replace(bad_bytes, b'')
    return encoded.decode(encoding=coding, errors='ignore')


def main(model_path, session_id: int = 1, repetition_penalty: float = 1.0):
    """An example to perform model inference through the command line
    interface.

    Type a prompt and press enter twice to run one round of inference;
    type ``end`` to reset the session, or ``exit`` to quit.

    Args:
        model_path (str): the path of the deployed model
        session_id (int): the identical id of a session
        repetition_penalty (float): penalty applied to repeated tokens
            during generation; 1.0 disables it
    """
    # The tokenizer is deployed alongside the model under triton_models/.
    tokenizer_model_path = osp.join(model_path, 'triton_models', 'tokenizer')
    tokenizer = Tokenizer(tokenizer_model_path)
    tm_model = tm.TurboMind(model_path, eos_id=tokenizer.eos_token_id)
    generator = tm_model.create_instance()

    # Per-session state: round counter, consumed-token count, and the
    # sampling seed (regenerated whenever the session is reset).
    nth_round = 1
    step = 0
    seed = random.getrandbits(64)
    model_name = tm_model.model_name
    model = MODELS.get(model_name)()

    while True:
        prompt = input_prompt()
        if prompt == 'exit':
            exit(0)
        elif prompt == 'end':
            # Reset the session: send an empty request with
            # sequence_end=True so the server releases the sequence.
            prompt = model.get_prompt('', nth_round == 1)
            input_ids = tokenizer.encode(prompt)
            for outputs in generator.stream_infer(session_id=session_id,
                                                  input_ids=[input_ids],
                                                  request_output_len=512,
                                                  sequence_start=False,
                                                  sequence_end=True):
                pass
            nth_round = 1
            step = 0
            seed = random.getrandbits(64)
        else:
            print(f'session {session_id}')
            # Refuse to continue once the accumulated context would
            # exceed the model's maximum session length.
            if step >= tm_model.session_len:
                print('WARNING: exceed session max length.'
                      ' Please end the session.')
                continue
            # Wrap the raw text with the model's chat template.
            prompt = model.get_prompt(prompt, nth_round == 1)
            input_ids = tokenizer.encode(prompt)
            print(f'{prompt} ', end='', flush=True)
            # Number of characters already printed from the streamed
            # response; used to print only the newly decoded suffix.
            response_size = 0
            for outputs in generator.stream_infer(
                    session_id=session_id,
                    input_ids=[input_ids],
                    stream_output=True,
                    request_output_len=512,
                    sequence_start=(nth_round == 1),
                    sequence_end=False,
                    step=step,
                    stop=False,
                    top_k=40,
                    top_p=0.8,
                    temperature=0.8,
                    repetition_penalty=repetition_penalty,
                    ignore_eos=False,
                    random_seed=seed if nth_round == 1 else None):
                res, tokens = outputs[0]
                # decode res
                response = tokenizer.decode(res)[response_size:]
                response = valid_str(response)
                print(f'{response}', end='', flush=True)
                response_size += len(response)

            # update step
            # NOTE(review): `tokens` is the generated-token count from
            # the last streamed output; step tracks total context used.
            step += len(input_ids) + tokens
            print()

            nth_round += 1
if __name__ == '__main__':
    # Expose main() as a CLI: positional/keyword args map to flags.
    fire.Fire(main)