# simple_server.py
import argparse

from cacheflow.core.server import (
    add_server_arguments, process_server_arguments,
    init_local_server_and_frontend_with_arguments)
from cacheflow.sampling_params import SamplingParams

# Last change committed by Woosuk Kwon (per the repository web view).

def main(args: argparse.Namespace) -> None:
    """Feed a few hard-coded prompts to a local server and print the results.

    Builds the server and frontend from ``args`` and then loops. Each
    iteration submits at most one pending test prompt, hands the frontend's
    queued inputs to the server, runs one server step, and prints every
    sequence group from that step that reports itself finished. The loop
    ends once no test prompts remain and the server reports no unfinished
    requests.

    Args:
        args: Parsed command-line arguments, forwarded unchanged to
            ``init_local_server_and_frontend_with_arguments``.
    """
    server, frontend = init_local_server_and_frontend_with_arguments(args)

    # Test the following inputs.
    # Each entry is (prompt text, keyword dict for SamplingParams.from_dict).
    test_inputs = [
        # Use default parameters.
        ("A robot may not injure a human being", {}),
        (
            "What is the meaning of life?",
            {"n": 3, "temperature": 0.8, "top_p": 0.99},
        ),
        (
            "It is only with the heart that one can see rightly",
            {"n": 4, "use_beam_search": True, "temperature": 0.0},
        ),
    ]

    while True:
        # Submit one prompt per iteration, so new requests are added while
        # earlier ones may still be in progress on the server.
        if test_inputs:
            text, sampling_params_dict = test_inputs.pop(0)
            sampling_params = SamplingParams.from_dict(sampling_params_dict)
            # NOTE(review): presumably registers the EOS token as a stop
            # condition -- confirm against the frontend implementation.
            sampling_params = frontend.add_eos_token(sampling_params)
            frontend.query(text, sampling_params)

        # Move whatever the frontend has queued over to the server, then
        # advance the server by a single step.
        server.add_sequence_groups(frontend.get_inputs())
        updated_seq_groups = server.step()

        for seq_group in updated_seq_groups:
            if seq_group.is_finished():
                frontend.print_response(seq_group)

        # Stop only when nothing is in flight and nothing is left to submit.
        if not (server.has_unfinished_requests() or test_inputs):
            break


if __name__ == '__main__':
    # Script entry point: build the CLI, let the server module register and
    # post-process its own options, then run the demo loop.
    parser = argparse.ArgumentParser(description='CacheFlow simple server.')
    parser = add_server_arguments(parser)
    args = parser.parse_args()
    # The parsed namespace is replaced by whatever process_server_arguments
    # returns before it is handed to main().
    args = process_server_arguments(args)
    main(args)