"vscode:/vscode.git/clone" did not exist on "b2e95f62b42692403bb691ae743f802e1eee7190"
manager_single.py 2.88 KB
Newer Older
1
"""A controller that manages a group of tensor parallel workers."""
Lianmin Zheng's avatar
Lianmin Zheng committed
2
3
4
5
6
7
import asyncio
import logging

import uvloop
import zmq
import zmq.asyncio
Liangsheng Yin's avatar
Liangsheng Yin committed
8

Lianmin Zheng's avatar
Lianmin Zheng committed
9
from sglang.global_config import global_config
10
from sglang.srt.managers.controller.tp_worker import ModelTpClient
Lianmin Zheng's avatar
Lianmin Zheng committed
11
from sglang.srt.server_args import PortArgs, ServerArgs
12
from sglang.utils import get_exception_traceback
Lianmin Zheng's avatar
Lianmin Zheng committed
13
14
15
16

# Install uvloop's event loop policy at import time so that event loops created
# afterwards (e.g. by asyncio.new_event_loop() in start_controller_process) are
# uvloop loops rather than the default asyncio implementation.
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())


17
18
class ControllerSingle:
    """Controller that drives a single model client over ZMQ.

    Requests pulled from the tokenizer socket are buffered in
    ``self.recv_reqs`` and handed to ``model_client.step`` in batches; every
    output object returned by ``step`` is pushed to the detokenizer socket.
    """

    def __init__(self, model_client: ModelTpClient, port_args: PortArgs):
        """Set up the ZMQ sockets and the pending-request buffer.

        Args:
            model_client: Client whose ``step`` coroutine runs one forward
                step over a list of requests and returns output objects.
            port_args: Port configuration; ``router_port`` is bound for
                incoming requests and ``detokenizer_port`` is connected to for
                outgoing results.
        """
        # Init communication (asyncio-aware zmq context with 2 I/O threads)
        context = zmq.asyncio.Context(2)
        self.recv_from_tokenizer = context.socket(zmq.PULL)
        self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.router_port}")

        self.send_to_detokenizer = context.socket(zmq.PUSH)
        self.send_to_detokenizer.connect(
            f"tcp://127.0.0.1:{port_args.detokenizer_port}"
        )

        # Init status
        self.model_client = model_client
        # Requests received since the last forward step was launched.
        self.recv_reqs = []

        # Init some configs
        self.request_dependency_delay = global_config.request_dependency_delay

    async def loop_for_forward(self):
        """Run forward steps forever and forward their outputs downstream.

        Each iteration takes all buffered requests, runs one
        ``model_client.step`` on them, pushes every output object to the
        detokenizer, then sleeps briefly so that new requests can arrive.
        """
        while True:
            # Take ownership of the pending requests and start a fresh buffer;
            # loop_for_recv_requests appends to whatever self.recv_reqs is now.
            # No await happens between the read and the reset, so no request
            # can slip in between.
            next_step_input, self.recv_reqs = self.recv_reqs, []
            out_pyobjs = await self.model_client.step(next_step_input)

            for obj in out_pyobjs:
                # NOTE(review): on a zmq.asyncio socket send_pyobj returns a
                # future that is not awaited here — confirm this is intended.
                self.send_to_detokenizer.send_pyobj(obj)

            # If some request just finished, wait the (configurable) dependency
            # delay so that a follow-up request depending on it can arrive and
            # avoid a cache miss; otherwise use the regular polling delay.
            # any() over an empty output list is False, so no length check is
            # needed.
            has_finished = any(obj.finished_reason is not None for obj in out_pyobjs)
            if has_finished and self.request_dependency_delay > 0:
                await asyncio.sleep(self.request_dependency_delay)
            else:
                await asyncio.sleep(global_config.wait_for_new_request_delay)

    async def loop_for_recv_requests(self):
        """Receive request objects from the tokenizer forever.

        Each received object is appended to ``self.recv_reqs``, where the next
        iteration of ``loop_for_forward`` picks it up.
        """
        while True:
            recv_req = await self.recv_from_tokenizer.recv_pyobj()
            self.recv_reqs.append(recv_req)


63
def start_controller_process(
    server_args: ServerArgs, port_args: PortArgs, pipe_writer, model_overide_args
):
    """Entry point of the controller process.

    Builds the model client and the controller, reports the initialization
    result through ``pipe_writer`` (the string "init ok" on success, the
    formatted traceback on failure), then runs the controller's receive and
    forward loops on a freshly created event loop.

    Args:
        server_args: Server configuration; ``log_level`` and ``tp_size`` are
            read here and the object is passed on to ``ModelTpClient``.
        port_args: Port configuration; ``model_port_args[0]`` is passed to the
            model client and the whole object to ``ControllerSingle``.
        pipe_writer: Object with a ``send`` method used to report init status;
            presumably one end of a multiprocessing pipe — confirm at caller.
        model_overide_args: Extra model overrides forwarded unchanged to
            ``ModelTpClient``.

    Raises:
        Exception: Any error raised while constructing the model client or the
            controller is re-raised after its traceback has been sent.
    """
    logging.basicConfig(
        level=getattr(logging, server_args.log_level.upper()),
        format="%(message)s",
    )

    try:
        model_client = ModelTpClient(
            list(range(server_args.tp_size)),
            server_args,
            port_args.model_port_args[0],
            model_overide_args,
        )
        controller = ControllerSingle(model_client, port_args)
    except Exception:
        # Let the parent see why initialization failed before we die.
        pipe_writer.send(get_exception_traceback())
        raise

    pipe_writer.send("init ok")

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    # Keep a strong reference to the receiver task: the event loop only keeps
    # weak references to tasks, so an unreferenced task may be garbage
    # collected mid-execution. This local stays alive while
    # run_until_complete blocks below.
    recv_task = loop.create_task(controller.loop_for_recv_requests())
    loop.run_until_complete(controller.loop_for_forward())