vllm_test.py
import configparser
import json
import logging
import time
import uuid
from typing import AsyncGenerator

from aiohttp import web
from transformers import AutoTokenizer
from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def init_model(model_path, use_vllm=False, tensor_parallel_size=1):
    '''Initialize the HuggingFace tokenizer and the vLLM async engine.'''
    # HuggingFace tokenizer
    logger.info("Initializing local LLM with the vLLM backend")
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    # vllm
    sampling_params = SamplingParams(temperature=1,
                                     top_p=0.95,
                                     max_tokens=1024,
                                     early_stopping=False,
                                     stop_token_ids=[tokenizer.eos_token_id])

    # Base vLLM engine configuration
    args = AsyncEngineArgs(model_path)
    args.worker_use_ray = False
    args.engine_use_ray = False
    args.tokenizer = model_path
    args.tensor_parallel_size = tensor_parallel_size
    args.trust_remote_code = True
    args.enforce_eager = True
    args.max_model_len = 1024
    args.dtype = 'float16'
    # Load the model and build the async engine
    engine = AsyncLLMEngine.from_engine_args(args)
    return engine, tokenizer, sampling_params

def llm_inference(args):
    '''Start a web server that accepts HTTP requests and answers them via the local LLM inference service.'''
    config = configparser.ConfigParser()
    config.read(args.config_path)

    bind_port = int(config['default']['bind_port'])
    model_path = config['llm']['local_llm_path']
    tensor_parallel_size = config.getint('llm', 'tensor_parallel_size')
    use_vllm = config.getboolean('llm', 'use_vllm')
    stream_chat = config.getboolean('llm', 'stream_chat')
    logger.info(f"Get params: model_path {model_path}, use_vllm {use_vllm}, tensor_parallel_size {tensor_parallel_size}, stream_chat {stream_chat}")

    model, tokenizer, sampling_params = init_model(model_path, use_vllm=use_vllm, tensor_parallel_size=tensor_parallel_size)

    async def inference(request):
        start = time.time()
        input_json = await request.json()

        prompt = input_json['query']
        history = input_json.get('history', [])  # accepted but not yet folded into the prompt
        messages = [{"role": "user", "content": prompt}]
        logger.info("****************** use vllm ******************")
        logger.info(f"before generate {messages}")
        # Render the chat messages into a prompt string with the model's chat template
        text = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True)
        logger.debug(text)
        assert model is not None
        request_id = uuid.uuid4().hex
        results_generator = model.generate(inputs=text, sampling_params=sampling_params, request_id=request_id)

        # Streaming case: emit each partial result as a NUL-delimited JSON chunk
        async def stream_results() -> AsyncGenerator[bytes, None]:
            async for request_output in results_generator:
                text_outputs = [output.text for output in request_output.outputs]
                ret = {"text": text_outputs}
                yield (json.dumps(ret) + "\0").encode("utf-8")

        if stream_chat:
            response = web.StreamResponse()
            await response.prepare(request)
            async for chunk in stream_results():
                await response.write(chunk)
            await response.write_eof()
            return response

        # Non-streaming case
        final_output = None
        async for request_output in results_generator:
            # If the client disconnects, the request could be aborted here with
            # `await model.abort(request_id)` before returning an error response.
            final_output = request_output

        assert final_output is not None

        text = [output.text for output in final_output.outputs]
        end = time.time()
        logger.debug('query: {} answer: {} \ntimecost {}'.format(prompt, text, end - start))
        return web.json_response({'text': text})

    app = web.Application()
    app.add_routes([web.post('/inference', inference)])
    web.run_app(app, host='0.0.0.0', port=bind_port)
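

# Minimal entry point sketch: llm_inference() only needs an object carrying a
# `config_path` attribute, so a small argparse wrapper is enough to start the
# server. The `--config_path` flag name and its default are assumptions, not
# part of the original script.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='vLLM inference HTTP service')
    parser.add_argument('--config_path', type=str, default='config.ini',
                        help='path to an ini file with [default] and [llm] sections')
    llm_inference(parser.parse_args())

# Example request against the /inference endpoint (replace <bind_port> with the
# configured port):
#   curl -X POST http://127.0.0.1:<bind_port>/inference \
#        -H 'Content-Type: application/json' \
#        -d '{"query": "Hello", "history": []}'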