Unverified commit 471650de, authored by lambert0312 and committed by GitHub
Browse files

Fix broadcast using the CUDA device, which leads to unbalanced memory capacity (#5416)

parent d06a83fb
...@@ -31,6 +31,7 @@ from sglang.srt.utils import is_port_available ...@@ -31,6 +31,7 @@ from sglang.srt.utils import is_port_available
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def find_available_ports(base_port: int, count: int) -> List[int]: def find_available_ports(base_port: int, count: int) -> List[int]:
"""Find consecutive available ports starting from base_port.""" """Find consecutive available ports starting from base_port."""
available_ports = [] available_ports = []
...@@ -43,6 +44,7 @@ def find_available_ports(base_port: int, count: int) -> List[int]: ...@@ -43,6 +44,7 @@ def find_available_ports(base_port: int, count: int) -> List[int]:
return available_ports return available_ports
def group_concurrent_contiguous( def group_concurrent_contiguous(
src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64] src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]: ) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
...@@ -265,7 +267,9 @@ class MooncakeKVManager(BaseKVManager): ...@@ -265,7 +267,9 @@ class MooncakeKVManager(BaseKVManager):
) )
if ret != 0: if ret != 0:
self.request_status[kv_chunk.room] = KVPoll.Failed self.request_status[kv_chunk.room] = KVPoll.Failed
self.sync_status_to_decode_endpoint(req.endpoint, req.dst_port, req.room) self.sync_status_to_decode_endpoint(
req.endpoint, req.dst_port, req.room
)
continue continue
if kv_chunk.is_last: if kv_chunk.is_last:
...@@ -279,7 +283,9 @@ class MooncakeKVManager(BaseKVManager): ...@@ -279,7 +283,9 @@ class MooncakeKVManager(BaseKVManager):
self.request_status[req.room] = ( self.request_status[req.room] = (
KVPoll.Success if ret == 0 else KVPoll.Failed KVPoll.Success if ret == 0 else KVPoll.Failed
) )
self.sync_status_to_decode_endpoint(req.endpoint, req.dst_port, req.room) self.sync_status_to_decode_endpoint(
req.endpoint, req.dst_port, req.room
)
self.transfer_infos.pop(req.room) self.transfer_infos.pop(req.room)
except queue.Empty: except queue.Empty:
...@@ -443,13 +449,14 @@ class MooncakeKVReceiver(BaseKVReceiver): ...@@ -443,13 +449,14 @@ class MooncakeKVReceiver(BaseKVReceiver):
prefill_info = response.json() prefill_info = response.json()
return prefill_info return prefill_info
else: else:
logger.error(f"Failed to get prefill server info: {response.status_code}, {response.text}") logger.error(
f"Failed to get prefill server info: {response.status_code}, {response.text}"
)
return None return None
except Exception as e: except Exception as e:
logger.error(f"Error fetching prefill info from bootstrap: {e}") logger.error(f"Error fetching prefill info from bootstrap: {e}")
return None return None
@cache @cache
def _connect(self, endpoint: str): def _connect(self, endpoint: str):
socket = zmq.Context().socket(zmq.PUSH) socket = zmq.Context().socket(zmq.PUSH)
...@@ -466,17 +473,25 @@ class MooncakeKVReceiver(BaseKVReceiver): ...@@ -466,17 +473,25 @@ class MooncakeKVReceiver(BaseKVReceiver):
) )
if prefill_info is None: if prefill_info is None:
logger.error( logger.error(
logger.error(f"Could not fetch prefill server info for tp_rank {self.kv_mgr.kv_args.engine_rank}") logger.error(
f"Could not fetch prefill server info for tp_rank {self.kv_mgr.kv_args.engine_rank}"
)
) )
else: else:
self.kv_mgr.connection_pool[self.kv_mgr.kv_args.engine_rank] = prefill_info self.kv_mgr.connection_pool[self.kv_mgr.kv_args.engine_rank] = (
prefill_info
)
else: else:
prefill_info = self.kv_mgr.connection_pool[self.kv_mgr.kv_args.engine_rank] prefill_info = self.kv_mgr.connection_pool[self.kv_mgr.kv_args.engine_rank]
if prefill_info: if prefill_info:
self.prefill_server_url = f"{prefill_info['serve_ip']}:{prefill_info['serve_port']}" self.prefill_server_url = (
f"{prefill_info['serve_ip']}:{prefill_info['serve_port']}"
)
logger.info(f"Fetched prefill server info: {prefill_info} for tp_rank {self.kv_mgr.kv_args.engine_rank}") logger.info(
f"Fetched prefill server info: {prefill_info} for tp_rank {self.kv_mgr.kv_args.engine_rank}"
)
self.handshake_prefill_server(kv_indices, aux_index) self.handshake_prefill_server(kv_indices, aux_index)
def handshake_prefill_server( def handshake_prefill_server(
...@@ -598,8 +613,13 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer): ...@@ -598,8 +613,13 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
# Add lock to make sure thread-safe # Add lock to make sure thread-safe
if role == "Prefill": if role == "Prefill":
async with self.lock: async with self.lock:
self.prefill_port_table[tp_rank] = {"serve_ip": serve_ip, "serve_port": serve_port} self.prefill_port_table[tp_rank] = {
logger.info(f"Registered Prefill tp_rank: {tp_rank} with serve_ip: {serve_ip} and serve_port: {serve_port}") "serve_ip": serve_ip,
"serve_port": serve_port,
}
logger.info(
f"Registered Prefill tp_rank: {tp_rank} with serve_ip: {serve_ip} and serve_port: {serve_port}"
)
return web.Response(text="OK", status=200) return web.Response(text="OK", status=200)
......
...@@ -118,6 +118,7 @@ class VerlEngine: ...@@ -118,6 +118,7 @@ class VerlEngine:
rank=self._tp_rank, rank=self._tp_rank,
dist_group=self._device_mesh_cpu.get_group(), dist_group=self._device_mesh_cpu.get_group(),
src=self._device_mesh_cpu.mesh[0].item(), src=self._device_mesh_cpu.mesh[0].item(),
force_cpu_device=False,
) )
return output return output
......
...@@ -846,9 +846,12 @@ def broadcast_pyobj( ...@@ -846,9 +846,12 @@ def broadcast_pyobj(
rank: int, rank: int,
dist_group: Optional[torch.distributed.ProcessGroup] = None, dist_group: Optional[torch.distributed.ProcessGroup] = None,
src: int = 0, src: int = 0,
force_cpu_device: bool = True,
): ):
"""Broadcast inputs from rank=0 to all other ranks with torch.dist backend.""" """Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device(
"cuda" if torch.cuda.is_available() and not force_cpu_device else "cpu"
)
if rank == 0: if rank == 0:
if len(data) == 0: if len(data) == 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment