chat_example.py 1.86 KB
Newer Older
q.yao's avatar
q.yao committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import fire
from lmdeploy import turbomind as tm
from lmdeploy.model import MODELS
from transformers import AutoTokenizer
import random


def input_prompt():
    print('\ndouble enter to end input >>> ', end='')
    sentinel = ''  # ends when this string is seen
    return '\n'.join(iter(input, sentinel))


def main(model_name, model_path, tokenizer_model_path, session_id: int = 1):
    tm_model = tm.TurboMind(model_path)
    generator = tm_model.create_instance()
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_model_path)
    model = MODELS.get(model_name)()

    nth_round = 1
    step = 0
    seed = random.getrandbits(64)

    while True:
        prompt = input_prompt()
        if prompt == 'exit':
            exit(0)
        elif prompt == 'end':
            pass
        else:
            prompt = model.get_prompt(prompt, nth_round == 1)
            input_ids = tokenizer.encode(prompt, add_special_tokens=False)
            for outputs in generator.stream_infer(
                    session_id=session_id,
                    input_ids=[input_ids],
                    request_output_len=512,
                    sequence_start=(nth_round == 1),
                    sequence_end=False,
                    step=step,
                    stop=False,
                    top_k=40,
                    top_p=0.8,
                    temperature=0.8,
                    repetition_penalty=1.05,
                    ignore_eos=False,
                    random_seed=seed if nth_round == 1 else None):
                res, tokens = outputs[0]
                # decode res
                response = tokenizer.decode(
                    res[step:], skip_special_tokens=True)
                print(f'session {session_id}, {tokens}, {response}')
                # update step
                step = tokens - 1

        nth_round += 1


if __name__ == '__main__':
    fire.Fire(main)