offline_streaming_inference_chat_demo.py 3.69 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114

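'''
Interactive multi-turn chat demo that streams replies from vLLM's AsyncLLMEngine.

Usage:
    python offline_streaming_inference_chat_demo.py --model /models/llama2/Llama-2-7b-chat-hf --dtype float16 --enforce-eager -tp 1

Type "stop" at the prompt to exit.
'''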
from vllm.sampling_params import SamplingParams
from vllm.engine.async_llm_engine import AsyncEngineArgs, AsyncLLMEngine
import asyncio
from transformers import AutoTokenizer
import logging
import argparse
import sys


if __name__ == '__main__':
    # Raise the "vllm" logger's threshold to WARNING; the intent is to keep
    # vLLM's INFO output from interleaving with the interactive transcript.
    vllm_logger = logging.getLogger(name="vllm")
    vllm_logger.setLevel(level=logging.WARNING)
    class FlexibleArgumentParser(argparse.ArgumentParser):
        """ArgumentParser that also accepts ``--snake_case`` long options.

        Before normal parsing, every token that begins with ``--`` has the
        underscores in its option name rewritten to dashes, so
        ``--max_model_len`` and ``--max-model-len`` are equivalent. An
        ``=value`` suffix is left untouched, as are short options and
        positional values.
        """

        def parse_args(self, args=None, namespace=None):
            raw_tokens = sys.argv[1:] if args is None else args

            def _normalize(token):
                # Only long options are rewritten; everything else passes through.
                if not token.startswith('--'):
                    return token
                name, sep, value = token[len('--'):].partition('=')
                return '--' + name.replace('_', '-') + sep + value

            normalized = [_normalize(token) for token in raw_tokens]
            return super().parse_args(normalized, namespace)
    
    parser = FlexibleArgumentParser()
    parser = AsyncEngineArgs.add_cli_args(parser)
    # Optional override for the tokenizer's built-in chat template. It is not
    # an engine field, so AsyncEngineArgs.from_cli_args ignores it.
    parser.add_argument(
        "--chat-template",
        type=str,
        default=None,
        help="Path to a Jinja chat-template file that replaces the "
             "tokenizer's built-in chat template.")
    args = parser.parse_args()

    # Load the tokenizer from the same source the engine is configured with
    # (--tokenizer, falling back to --model), so the prompt is formatted with
    # the template belonging to the tokenizer the engine uses.
    tokenizer = AutoTokenizer.from_pretrained(
        getattr(args, "tokenizer", None) or args.model,
        trust_remote_code=getattr(args, "trust_remote_code", False))
    if args.chat_template is not None:
        # Let a missing or unreadable file raise instead of silently
        # continuing with the default template.
        with open(args.chat_template, encoding="utf-8") as template_file:
            tokenizer.chat_template = template_file.read()

    engine_args = AsyncEngineArgs.from_cli_args(args)
    engine = AsyncLLMEngine.from_engine_args(engine_args)

    # Display name for the banner and transcript labels: the last non-empty
    # path component of --model ("/models/llama/" -> "llama"). Falls back to
    # the raw value when there is none (e.g. "/").
    model_name = args.model.rstrip("/").rsplit("/", 1)[-1] or args.model
    print(f"欢迎使用{model_name}模型,输入内容即可进行对话,stop 终止程序")


    def build_prompt(history, assistant_name=None):
        """Render a conversation as plain text with "用户:"/"<model>:" speaker labels.

        Args:
            history: either ``(query, response)`` pairs, or chat messages of
                the form ``{"role": ..., "content": ...}`` (the format the
                chat loop in this script stores). A message whose role is
                ``"user"`` gets the user label; any other role gets the
                assistant label.
            assistant_name: label used for assistant turns; defaults to the
                script-level ``model_name``.

        Returns:
            The rendered transcript. Every turn is preceded by a blank line
            (``"\\n\\n"``). An empty history yields ``""``.
        """
        label = model_name if assistant_name is None else assistant_name
        parts = []
        for turn in history:
            if isinstance(turn, dict):
                # Role/content messages: one labelled line per message.
                speaker = "用户" if turn.get("role") == "user" else label
                parts.append(f"\n\n{speaker}:{turn['content']}")
            else:
                query, response = turn
                parts.append(f"\n\n用户:{query}")
                parts.append(f"\n\n{label}:{response}")
        return "".join(parts)


    async def stream_reply(prompt, request_id):
        """Generate a reply to *prompt*, echoing it to stdout while it streams.

        Each output yielded by ``engine.generate`` is treated as holding the
        full text generated so far, so only the not-yet-printed suffix is
        written on each step.

        Args:
            prompt: the fully formatted chat prompt string.
            request_id: a string id unique to this request.

        Returns:
            The final generated text ("" if nothing was yielded).
        """
        # Greedy decoding; each reply is capped at 100 new tokens, so longer
        # answers are truncated.
        sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
        text = ""
        printed = 0
        async for output in engine.generate(prompt, sampling_params, request_id):
            text = output.outputs[0].text
            print(text[printed:], end="", flush=True)
            printed = len(text)
        return text

    history = []
    turn = 0
    while True:
        try:
            query = input("\n用户:")
        except EOFError:
            # Ctrl-D / closed stdin ends the session the same way "stop" does.
            break
        if query.strip() == "stop":
            break
        history.append({"role": "user", "content": query})
        # add_generation_prompt=True appends the assistant-turn header for
        # templates that define one, so the model answers as the assistant.
        new_query = tokenizer.apply_chat_template(
            history, tokenize=False, add_generation_prompt=True)
        # NOTE(review): asyncio.run() creates and closes a fresh event loop on
        # every turn. This relies on AsyncLLMEngine restarting its background
        # loop on the new event loop; confirm for the vLLM version in use.
        response = asyncio.run(stream_reply(new_query, str(turn)))
        history.append({"role": "assistant", "content": response})
        turn += 1
    print()