# simple_server.py -- CacheFlow simple server: runs a few test prompts through a local server.
import argparse

from cacheflow.master.server import (
    add_server_arguments, process_server_arguments,
    init_local_server_and_frontend_with_arguments)
from cacheflow.sampling_params import SamplingParams

def main(args: argparse.Namespace) -> None:
    """Run a few hard-coded test prompts through a local CacheFlow server.

    Builds a server/frontend pair from ``args``, then repeatedly submits the
    next pending test prompt (if any), steps the server, and prints every
    sequence group that finished in that step. The loop exits once all test
    prompts have been submitted and the server reports no unfinished requests.

    Args:
        args: Parsed command-line namespace, as built in the ``__main__``
            block via ``add_server_arguments`` / ``process_server_arguments``.
    """
    server, frontend = init_local_server_and_frontend_with_arguments(args)

    # Test the following inputs. Each entry is (prompt, kwargs passed to
    # SamplingParams.from_dict); an empty dict uses the default parameters.
    test_inputs = [
        ('Ion Stoica is a', {'n': 4, 'use_beam_search': True, 'temperature': 0.0}),
        ('UC Berkeley is', {'n': 3, 'temperature': 0.8, 'top_p': 0.99}),
        ('The future of cloud computing is', {}),   # Use default parameters.
    ]
    while True:
        # Submit at most one new prompt per iteration. Later prompts are
        # therefore added after the server has already stepped on earlier ones.
        if test_inputs:
            text, sampling_params_dict = test_inputs.pop(0)
            sampling_params = SamplingParams.from_dict(sampling_params_dict)
            sampling_params = frontend.add_eos_token(sampling_params)
            frontend.query(text, sampling_params)
        # Move whatever the frontend has queued into the server, then run one step.
        server.add_sequence_groups(frontend.get_inputs())
        updated_seq_groups = server.step()
        for seq_group in updated_seq_groups:
            if seq_group.is_finished():
                frontend.print_response(seq_group)
        # Stop only when nothing remains to submit and nothing is still running.
        if not (server.has_unfinished_requests() or test_inputs):
            break


if __name__ == '__main__':
    # Build the CLI from the shared server flags, post-process the parsed
    # values with process_server_arguments, then run the test loop.
    parser = argparse.ArgumentParser(description='CacheFlow simple server.')
    parser = add_server_arguments(parser)
    args = parser.parse_args()
    args = process_server_arguments(args)
    main(args)