"vscode:/vscode.git/clone" did not exist on "a2ba46e9f88484dc92841184f7fb8ae6936c9bb8"
Unverified commit 9c5cac24, authored by zhyncs and committed by GitHub
Browse files

fix: resolve lint error (#650)

parent 5960a6e5
...@@ -10,6 +10,6 @@ Briefly describe the changes made in this PR. ...@@ -10,6 +10,6 @@ Briefly describe the changes made in this PR.
## Checklist ## Checklist
1. Ensure pre-commit or other linting tools are used to fix potential lint issues. 1. Ensure pre-commit `pre-commit run --all-files` or other linting tools are used to fix potential lint issues.
2. Confirm that modifications are covered by complete unit tests. If not, please add more unit tests for correctness. 2. Confirm that modifications are covered by complete unit tests. If not, please add more unit tests for correctness.
3. Modify documentation as needed, such as docstrings or example tutorials. 3. Modify documentation as needed, such as docstrings or example tutorials.
...@@ -29,6 +29,7 @@ logger = logging.getLogger("srt.controller") ...@@ -29,6 +29,7 @@ logger = logging.getLogger("srt.controller")
class LoadBalanceMethod(Enum): class LoadBalanceMethod(Enum):
"""Load balance method.""" """Load balance method."""
ROUND_ROBIN = auto() ROUND_ROBIN = auto()
SHORTEST_QUEUE = auto() SHORTEST_QUEUE = auto()
...@@ -44,6 +45,7 @@ class LoadBalanceMethod(Enum): ...@@ -44,6 +45,7 @@ class LoadBalanceMethod(Enum):
@dataclasses.dataclass @dataclasses.dataclass
class WorkerHandle: class WorkerHandle:
"""Store the handle of a data parallel worker.""" """Store the handle of a data parallel worker."""
proc: multiprocessing.Process proc: multiprocessing.Process
queue: multiprocessing.Queue queue: multiprocessing.Queue
...@@ -62,7 +64,8 @@ class ControllerMulti: ...@@ -62,7 +64,8 @@ class ControllerMulti:
self.port_args = port_args self.port_args = port_args
self.model_overide_args = model_overide_args self.model_overide_args = model_overide_args
self.load_balance_method = LoadBalanceMethod.from_str( self.load_balance_method = LoadBalanceMethod.from_str(
server_args.load_balance_method) server_args.load_balance_method
)
# Init communication # Init communication
context = zmq.Context() context = zmq.Context()
...@@ -85,7 +88,9 @@ class ControllerMulti: ...@@ -85,7 +88,9 @@ class ControllerMulti:
def start_dp_worker(self, dp_worker_id: int): def start_dp_worker(self, dp_worker_id: int):
tp_size = self.server_args.tp_size tp_size = self.server_args.tp_size
pipe_controller_reader, pipe_controller_writer = multiprocessing.Pipe(duplex=False) pipe_controller_reader, pipe_controller_writer = multiprocessing.Pipe(
duplex=False
)
gpu_ids = list(range(dp_worker_id * tp_size, (dp_worker_id + 1) * tp_size)) gpu_ids = list(range(dp_worker_id * tp_size, (dp_worker_id + 1) * tp_size))
queue = multiprocessing.Queue() queue = multiprocessing.Queue()
...@@ -100,7 +105,7 @@ class ControllerMulti: ...@@ -100,7 +105,7 @@ class ControllerMulti:
gpu_ids, gpu_ids,
dp_worker_id, dp_worker_id,
queue, queue,
) ),
) )
proc.start() proc.start()
...@@ -109,10 +114,12 @@ class ControllerMulti: ...@@ -109,10 +114,12 @@ class ControllerMulti:
raise RuntimeError( raise RuntimeError(
f"Initialization failed. controller_init_state: {controller_init_state}" f"Initialization failed. controller_init_state: {controller_init_state}"
) )
self.workers.append(WorkerHandle( self.workers.append(
WorkerHandle(
proc=proc, proc=proc,
queue=queue, queue=queue,
)) )
)
def round_robin_scheduler(self, input_requests): def round_robin_scheduler(self, input_requests):
for r in input_requests: for r in input_requests:
......
...@@ -8,7 +8,9 @@ from typing import List ...@@ -8,7 +8,9 @@ from typing import List
import zmq import zmq
from sglang.srt.managers.controller.tp_worker import ( from sglang.srt.managers.controller.tp_worker import (
broadcast_recv_input, launch_tp_servers, ModelTpServer ModelTpServer,
broadcast_recv_input,
launch_tp_servers,
) )
from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import kill_parent_process from sglang.srt.utils import kill_parent_process
...@@ -41,7 +43,9 @@ class ControllerSingle: ...@@ -41,7 +43,9 @@ class ControllerSingle:
if not self.is_dp_worker: if not self.is_dp_worker:
self.recv_from_tokenizer = context.socket(zmq.PULL) self.recv_from_tokenizer = context.socket(zmq.PULL)
self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.controller_port}") self.recv_from_tokenizer.bind(
f"tcp://127.0.0.1:{port_args.controller_port}"
)
self.send_to_detokenizer = context.socket(zmq.PUSH) self.send_to_detokenizer = context.socket(zmq.PUSH)
self.send_to_detokenizer.connect( self.send_to_detokenizer.connect(
...@@ -128,9 +132,15 @@ def start_controller_process( ...@@ -128,9 +132,15 @@ def start_controller_process(
queue = None queue = None
try: try:
controller = ControllerSingle(server_args, port_args, model_overide_args, controller = ControllerSingle(
gpu_ids, is_data_parallel_worker, server_args,
dp_worker_id, queue) port_args,
model_overide_args,
gpu_ids,
is_data_parallel_worker,
dp_worker_id,
queue,
)
except Exception: except Exception:
pipe_writer.send(get_exception_traceback()) pipe_writer.send(get_exception_traceback())
raise raise
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment