test_vllm_utils.py
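"""Unit tests for the vLLM inference helpers in swift.llm.utils."""
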
import unittest

import torch

# The wildcard import provides ModelType, get_vllm_engine, get_default_template_type,
# get_template, inference_vllm and inference_stream_vllm used below.
from swift.llm.utils import *

SKIP_TEST = True


class TestVllmUtils(unittest.TestCase):

    @unittest.skipIf(SKIP_TEST, 'To avoid citest error: OOM')
    def test_inference_vllm(self):
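        # Build a vLLM engine for Qwen-7B-Chat in fp16 and pair it with the
        # model's default chat template.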
        model_type = ModelType.qwen_7b_chat
        llm_engine = get_vllm_engine(model_type, torch.float16)
        template_type = get_default_template_type(model_type)
        template = get_template(template_type, llm_engine.hf_tokenizer)
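        # Two sample requests (in Chinese): 'Where is the capital of Zhejiang?'
        # and 'Hello!'.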
        request_list = [{'query': '浙江的省会在哪?'}, {'query': '你好!'}]
        # Test batch (non-streaming) inference via inference_vllm.
        response_list = inference_vllm(llm_engine, template, request_list, verbose=True)
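        # Each element is a dict carrying at least 'response' and 'history' keys.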
        for response in response_list:
            print(response)

        # Test streaming inference via inference_stream_vllm: each iteration
        # yields the whole batch with incrementally updated responses.
        gen = inference_stream_vllm(llm_engine, template, request_list)
        for response_list in gen:
            print(response_list[0]['response'], response_list[0]['history'])
            print(response_list[1]['response'], response_list[1]['history'])


if __name__ == '__main__':
    unittest.main()