# utils.py -- vLLM inference smoke test (last committed by wanglch).
from dataclasses import dataclass

from swift.llm import get_default_template_type, get_template, get_vllm_engine, inference_vllm
from swift.utils import get_main


@dataclass
class VLLMTestArgs:
    """Arguments for the vLLM inference smoke test defined in this file.

    An instance is what ``test_vllm`` receives; the class itself is handed to
    ``get_main`` together with ``test_vllm`` at module level.
    """
    # Model identifier (required, no default). ``test_vllm`` passes it to
    # ``get_vllm_engine`` and ``get_default_template_type``.
    model_type: str


def test_vllm(args: VLLMTestArgs) -> None:
    """Run a small batch of prompts through a vLLM engine and print the replies.

    Builds the engine for ``args.model_type``, resolves that model's default
    chat template, limits generation to 256 new tokens, then sends two fixed
    queries in one ``inference_vllm`` call and prints each query next to its
    response.

    Args:
        args: Test arguments; only ``model_type`` is read.
    """
    engine = get_vllm_engine(args.model_type)
    # The template is created with the engine's own tokenizer.
    template = get_template(
        get_default_template_type(args.model_type),
        engine.hf_tokenizer,
    )

    # Keep the smoke test short by bounding the generated length.
    engine.generation_config.max_new_tokens = 256

    # Two fixed prompts: "Hello!" and "Where is the capital of Zhejiang?".
    prompts = ('你好!', '浙江的省会在哪?')
    request_list = [{'query': prompt} for prompt in prompts]
    responses = inference_vllm(engine, template, request_list)

    # Pair each request with its response positionally.
    for sent, received in zip(request_list, responses):
        print(f"query: {sent['query']}")
        print(f"response: {received['response']}")


# Entry point produced by swift's ``get_main`` from the argument dataclass and
# the test function.
# NOTE(review): presumably this parses command-line arguments into
# VLLMTestArgs and then calls test_vllm -- confirm against swift.utils.get_main.
test_vllm_main = get_main(VLLMTestArgs, test_vllm)

if __name__ == '__main__':
    # Only run the smoke test when this file is executed as a script.
    test_vllm_main()