Unverified Commit 471650de authored by lambert0312, committed by GitHub

Fix broadcast using a CUDA device leading to unbalanced memory capacity (#5416)

parent d06a83fb
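
Context, inferred from the diff below: broadcast_pyobj previously picked torch.device("cuda") whenever CUDA was available, so every caller staged its serialized payload on a GPU (typically the process's current CUDA device, often device 0), which skewed per-GPU memory usage. The new force_cpu_device flag defaults to True, keeping the staging tensor on CPU; VerlEngine explicitly opts back into CUDA staging with force_cpu_device=False.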
@@ -31,6 +31,7 @@ from sglang.srt.utils import is_port_available
 logger = logging.getLogger(__name__)
 
+
 def find_available_ports(base_port: int, count: int) -> List[int]:
     """Find consecutive available ports starting from base_port."""
     available_ports = []
@@ -43,6 +44,7 @@ def find_available_ports(base_port: int, count: int) -> List[int]:
     return available_ports
 
+
 def group_concurrent_contiguous(
     src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
 ) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
@@ -265,7 +267,9 @@ class MooncakeKVManager(BaseKVManager):
                 )
                 if ret != 0:
                     self.request_status[kv_chunk.room] = KVPoll.Failed
-                    self.sync_status_to_decode_endpoint(req.endpoint, req.dst_port, req.room)
+                    self.sync_status_to_decode_endpoint(
+                        req.endpoint, req.dst_port, req.room
+                    )
                     continue
 
                 if kv_chunk.is_last:
@@ -279,7 +283,9 @@ class MooncakeKVManager(BaseKVManager):
                     self.request_status[req.room] = (
                         KVPoll.Success if ret == 0 else KVPoll.Failed
                     )
-                    self.sync_status_to_decode_endpoint(req.endpoint, req.dst_port, req.room)
+                    self.sync_status_to_decode_endpoint(
+                        req.endpoint, req.dst_port, req.room
+                    )
                     self.transfer_infos.pop(req.room)
             except queue.Empty:
@@ -443,13 +449,14 @@ class MooncakeKVReceiver(BaseKVReceiver):
                 prefill_info = response.json()
                 return prefill_info
             else:
-                logger.error(f"Failed to get prefill server info: {response.status_code}, {response.text}")
+                logger.error(
+                    f"Failed to get prefill server info: {response.status_code}, {response.text}"
+                )
                 return None
         except Exception as e:
             logger.error(f"Error fetching prefill info from bootstrap: {e}")
             return None
 
     @cache
     def _connect(self, endpoint: str):
         socket = zmq.Context().socket(zmq.PUSH)
@@ -466,17 +473,25 @@ class MooncakeKVReceiver(BaseKVReceiver):
             )
             if prefill_info is None:
-                logger.error(f"Could not fetch prefill server info for tp_rank {self.kv_mgr.kv_args.engine_rank}")
+                logger.error(
+                    f"Could not fetch prefill server info for tp_rank {self.kv_mgr.kv_args.engine_rank}"
+                )
             else:
-                self.kv_mgr.connection_pool[self.kv_mgr.kv_args.engine_rank] = prefill_info
+                self.kv_mgr.connection_pool[self.kv_mgr.kv_args.engine_rank] = (
+                    prefill_info
+                )
         else:
             prefill_info = self.kv_mgr.connection_pool[self.kv_mgr.kv_args.engine_rank]
 
         if prefill_info:
-            self.prefill_server_url = f"{prefill_info['serve_ip']}:{prefill_info['serve_port']}"
+            self.prefill_server_url = (
+                f"{prefill_info['serve_ip']}:{prefill_info['serve_port']}"
+            )
-            logger.info(f"Fetched prefill server info: {prefill_info} for tp_rank {self.kv_mgr.kv_args.engine_rank}")
+            logger.info(
+                f"Fetched prefill server info: {prefill_info} for tp_rank {self.kv_mgr.kv_args.engine_rank}"
+            )
             self.handshake_prefill_server(kv_indices, aux_index)
 
     def handshake_prefill_server(
@@ -598,8 +613,13 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
         # Add lock to make sure thread-safe
         if role == "Prefill":
             async with self.lock:
-                self.prefill_port_table[tp_rank] = {"serve_ip": serve_ip, "serve_port": serve_port}
-            logger.info(f"Registered Prefill tp_rank: {tp_rank} with serve_ip: {serve_ip} and serve_port: {serve_port}")
+                self.prefill_port_table[tp_rank] = {
+                    "serve_ip": serve_ip,
+                    "serve_port": serve_port,
+                }
+            logger.info(
+                f"Registered Prefill tp_rank: {tp_rank} with serve_ip: {serve_ip} and serve_port: {serve_port}"
+            )
 
         return web.Response(text="OK", status=200)
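
The hunk above is formatting-only, but it sits on the path the file's own comment flags as thread-safety-critical: PUT handlers from different prefill ranks all write prefill_port_table, so the write is guarded by an asyncio lock. A stripped-down sketch of that pattern (the free-standing function and its argument plumbing are assumptions for illustration; only the names from the diff are real):

```python
import asyncio

from aiohttp import web

lock = asyncio.Lock()
prefill_port_table: dict = {}


async def register(role, tp_rank, serve_ip, serve_port):
    # Serialize writes from concurrently running handlers: coroutines can
    # interleave, so the shared table is only mutated while holding the lock.
    if role == "Prefill":
        async with lock:
            prefill_port_table[tp_rank] = {
                "serve_ip": serve_ip,
                "serve_port": serve_port,
            }
    return web.Response(text="OK", status=200)
```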
......
@@ -118,6 +118,7 @@ class VerlEngine:
             rank=self._tp_rank,
             dist_group=self._device_mesh_cpu.get_group(),
             src=self._device_mesh_cpu.mesh[0].item(),
+            force_cpu_device=False,
         )
         return output
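
Caller-side contrast, using the signature from the diff (tp_rank, cpu_group, and data are placeholder variables): most call sites inherit the new CPU default, while VerlEngine opts out to keep its previous CUDA staging.

```python
# Default path: the serialized payload is staged on CPU.
out = broadcast_pyobj(data, rank=tp_rank, dist_group=cpu_group)

# VerlEngine path: opt back into CUDA staging explicitly.
out = broadcast_pyobj(
    data,
    rank=tp_rank,
    dist_group=cpu_group,
    src=0,
    force_cpu_device=False,
)
```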
......
@@ -846,9 +846,12 @@ def broadcast_pyobj(
     rank: int,
     dist_group: Optional[torch.distributed.ProcessGroup] = None,
     src: int = 0,
+    force_cpu_device: bool = True,
 ):
     """Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    device = torch.device(
+        "cuda" if torch.cuda.is_available() and not force_cpu_device else "cpu"
+    )
 
     if rank == 0:
         if len(data) == 0:
......
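
For context, a minimal, self-contained sketch of how a helper like this uses the staging device. The pickle-based wire format below is an assumption for illustration; the real sglang implementation differs (e.g. it special-cases the empty-payload branch shown above).

```python
import pickle
from typing import Any, List, Optional

import torch
import torch.distributed


def broadcast_pyobj_sketch(
    data: List[Any],
    rank: int,
    dist_group: Optional[torch.distributed.ProcessGroup] = None,
    src: int = 0,
    force_cpu_device: bool = True,
):
    """Broadcast Python objects from src to all ranks (illustrative only)."""
    # The fix: stage on CPU unless the caller explicitly opts into CUDA,
    # so the broadcast itself adds no GPU allocations by default.
    device = torch.device(
        "cuda" if torch.cuda.is_available() and not force_cpu_device else "cpu"
    )

    if rank == src:
        # Sender: pickle the objects, then broadcast length and bytes.
        serialized = pickle.dumps(data)
        size = torch.tensor([len(serialized)], dtype=torch.long, device=device)
        payload = torch.frombuffer(bytearray(serialized), dtype=torch.uint8).to(device)
        torch.distributed.broadcast(size, src=src, group=dist_group)
        torch.distributed.broadcast(payload, src=src, group=dist_group)
        return data
    else:
        # Receiver: learn the length, allocate a buffer, then unpickle.
        size = torch.empty(1, dtype=torch.long, device=device)
        torch.distributed.broadcast(size, src=src, group=dist_group)
        payload = torch.empty(int(size.item()), dtype=torch.uint8, device=device)
        torch.distributed.broadcast(payload, src=src, group=dist_group)
        return pickle.loads(payload.cpu().numpy().tobytes())
```

One practical constraint worth noting: the staging device must be compatible with the process-group backend (NCCL groups require CUDA tensors, while gloo accepts CPU ones), which is why the flag is a per-call-site choice rather than a hard-coded default.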