Unverified Commit 9c5cac24 authored by zhyncs's avatar zhyncs Committed by GitHub
Browse files

fix: resolve lint error (#650)

parent 5960a6e5
......@@ -10,6 +10,6 @@ Briefly describe the changes made in this PR.
## Checklist
1. Ensure pre-commit or other linting tools are used to fix potential lint issues.
1. Ensure pre-commit `pre-commit run --all-files` or other linting tools are used to fix potential lint issues.
2. Confirm that modifications are covered by complete unit tests. If not, please add more unit tests for correctness.
3. Modify documentation as needed, such as docstrings or example tutorials.
......@@ -29,6 +29,7 @@ logger = logging.getLogger("srt.controller")
class LoadBalanceMethod(Enum):
"""Load balance method."""
ROUND_ROBIN = auto()
SHORTEST_QUEUE = auto()
......@@ -44,6 +45,7 @@ class LoadBalanceMethod(Enum):
@dataclasses.dataclass
class WorkerHandle:
"""Store the handle of a data parallel worker."""
proc: multiprocessing.Process
queue: multiprocessing.Queue
......@@ -62,7 +64,8 @@ class ControllerMulti:
self.port_args = port_args
self.model_overide_args = model_overide_args
self.load_balance_method = LoadBalanceMethod.from_str(
server_args.load_balance_method)
server_args.load_balance_method
)
# Init communication
context = zmq.Context()
......@@ -85,7 +88,9 @@ class ControllerMulti:
def start_dp_worker(self, dp_worker_id: int):
tp_size = self.server_args.tp_size
pipe_controller_reader, pipe_controller_writer = multiprocessing.Pipe(duplex=False)
pipe_controller_reader, pipe_controller_writer = multiprocessing.Pipe(
duplex=False
)
gpu_ids = list(range(dp_worker_id * tp_size, (dp_worker_id + 1) * tp_size))
queue = multiprocessing.Queue()
......@@ -100,7 +105,7 @@ class ControllerMulti:
gpu_ids,
dp_worker_id,
queue,
)
),
)
proc.start()
......@@ -109,10 +114,12 @@ class ControllerMulti:
raise RuntimeError(
f"Initialization failed. controller_init_state: {controller_init_state}"
)
self.workers.append(WorkerHandle(
proc=proc,
queue=queue,
))
self.workers.append(
WorkerHandle(
proc=proc,
queue=queue,
)
)
def round_robin_scheduler(self, input_requests):
for r in input_requests:
......
......@@ -8,7 +8,9 @@ from typing import List
import zmq
from sglang.srt.managers.controller.tp_worker import (
broadcast_recv_input, launch_tp_servers, ModelTpServer
ModelTpServer,
broadcast_recv_input,
launch_tp_servers,
)
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import kill_parent_process
......@@ -41,7 +43,9 @@ class ControllerSingle:
if not self.is_dp_worker:
self.recv_from_tokenizer = context.socket(zmq.PULL)
self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.controller_port}")
self.recv_from_tokenizer.bind(
f"tcp://127.0.0.1:{port_args.controller_port}"
)
self.send_to_detokenizer = context.socket(zmq.PUSH)
self.send_to_detokenizer.connect(
......@@ -128,9 +132,15 @@ def start_controller_process(
queue = None
try:
controller = ControllerSingle(server_args, port_args, model_overide_args,
gpu_ids, is_data_parallel_worker,
dp_worker_id, queue)
controller = ControllerSingle(
server_args,
port_args,
model_overide_args,
gpu_ids,
is_data_parallel_worker,
dp_worker_id,
queue,
)
except Exception:
pipe_writer.send(get_exception_traceback())
raise
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment