Unverified commit 1e61b496, authored by Lianmin Zheng, committed by GitHub

[Auto Sync] Update parallel_state.py (20250830) (#9828)


Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent 300676af
@@ -52,6 +52,8 @@ from sglang.srt.utils import (
 _is_npu = is_npu()
+IS_ONE_DEVICE_PER_PROCESS = get_bool_env_var("SGLANG_ONE_DEVICE_PER_PROCESS")

 @dataclass
 class GraphCaptureContext:
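For orientation, a rough sketch of how a boolean environment flag like this is typically read. The helper body below is an assumption for illustration; only the name get_bool_env_var and the SGLANG_ONE_DEVICE_PER_PROCESS variable come from the diff itself:

import os

def get_bool_env_var(name: str, default: str = "false") -> bool:
    # Assumed behavior: common truthy spellings enable the flag;
    # anything else (including an unset variable) leaves it disabled.
    return os.getenv(name, default).strip().lower() in ("1", "true", "yes")

# Mirrors the new module-level flag added in the hunk above.
IS_ONE_DEVICE_PER_PROCESS = get_bool_env_var("SGLANG_ONE_DEVICE_PER_PROCESS")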
@@ -223,10 +225,12 @@ class GroupCoordinator:
         use_message_queue_broadcaster: bool = False,
         group_name: Optional[str] = None,
     ):
+        # Set group info
         group_name = group_name or "anonymous"
         self.unique_name = _get_unique_name(group_name)
         _register_group(self)

+        # Set rank info
         self.rank = torch.distributed.get_rank()
         self.local_rank = local_rank
         self.device_group = None
@@ -250,14 +254,16 @@ class GroupCoordinator:
         assert self.cpu_group is not None
         assert self.device_group is not None

+        device_id = 0 if IS_ONE_DEVICE_PER_PROCESS else local_rank
         if is_cuda_alike():
-            self.device = torch.device(f"cuda:{local_rank}")
+            self.device = torch.device(f"cuda:{device_id}")
         elif _is_npu:
-            self.device = torch.device(f"npu:{local_rank}")
+            self.device = torch.device(f"npu:{device_id}")
         else:
             self.device = torch.device("cpu")
         self.device_module = torch.get_device_module(self.device)

+        # Import communicators
         self.use_pynccl = use_pynccl
         self.use_pymscclpp = use_pymscclpp
         self.use_custom_allreduce = use_custom_allreduce
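A minimal sketch of the new device-selection behavior, pulled out into a standalone helper for illustration (the helper and its arguments are hypothetical; the real logic lives inside GroupCoordinator.__init__). The flag presumably covers deployments where each process is only given a single visible accelerator, so ordinal 0 is the correct index regardless of the local rank:

import torch

def select_device(local_rank: int, one_device_per_process: bool) -> torch.device:
    # With one visible device per process, always address it as index 0;
    # otherwise the process's local rank selects the device ordinal.
    device_id = 0 if one_device_per_process else local_rank
    if torch.cuda.is_available():
        return torch.device(f"cuda:{device_id}")
    return torch.device("cpu")

# Example: local rank 3, but the process only sees one GPU -> cuda:0
print(select_device(3, one_device_per_process=True))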
@@ -270,6 +276,9 @@ class GroupCoordinator:
         from sglang.srt.distributed.device_communicators.custom_all_reduce import (
             CustomAllreduce,
         )
+        from sglang.srt.distributed.device_communicators.pymscclpp import (
+            PyMscclppCommunicator,
+        )
         from sglang.srt.distributed.device_communicators.pynccl import (
             PyNcclCommunicator,
         )
@@ -287,10 +296,6 @@ class GroupCoordinator:
                 device=self.device,
             )

-        from sglang.srt.distributed.device_communicators.pymscclpp import (
-            PyMscclppCommunicator,
-        )
-
         self.pymscclpp_comm: Optional[PyMscclppCommunicator] = None
         if use_pymscclpp and self.world_size > 1:
             self.pymscclpp_comm = PyMscclppCommunicator(
@@ -325,30 +330,30 @@ class GroupCoordinator:
         except Exception as e:
             logger.warning(f"Failed to initialize QuickAllReduce: {e}")

+        # Create communicator for other hardware backends
         from sglang.srt.distributed.device_communicators.hpu_communicator import (
             HpuCommunicator,
         )
+        from sglang.srt.distributed.device_communicators.npu_communicator import (
+            NpuCommunicator,
+        )
+        from sglang.srt.distributed.device_communicators.xpu_communicator import (
+            XpuCommunicator,
+        )

         self.hpu_communicator: Optional[HpuCommunicator] = None
         if use_hpu_communicator and self.world_size > 1:
             self.hpu_communicator = HpuCommunicator(group=self.device_group)

-        from sglang.srt.distributed.device_communicators.xpu_communicator import (
-            XpuCommunicator,
-        )
-
         self.xpu_communicator: Optional[XpuCommunicator] = None
         if use_xpu_communicator and self.world_size > 1:
             self.xpu_communicator = XpuCommunicator(group=self.device_group)

-        from sglang.srt.distributed.device_communicators.npu_communicator import (
-            NpuCommunicator,
-        )
-
         self.npu_communicator: Optional[NpuCommunicator] = None
         if use_npu_communicator and self.world_size > 1:
             self.npu_communicator = NpuCommunicator(group=self.device_group)

+        # Create message queue
         from sglang.srt.distributed.device_communicators.shm_broadcast import (
             MessageQueue,
         )
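The hunk above keeps a single pattern per hardware backend: import the communicator class lazily, then construct it only when its flag is enabled and the group actually spans more than one rank. A generic sketch of that pattern, with placeholder names that are not sglang APIs:

from typing import Any, Optional, Type

def maybe_create_communicator(
    enabled: bool, world_size: int, device_group: Any, comm_cls: Type
) -> Optional[Any]:
    # Skip construction for single-rank groups or disabled backends,
    # mirroring the `use_*_communicator and self.world_size > 1` checks above.
    if enabled and world_size > 1:
        return comm_cls(group=device_group)
    return None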
@@ -848,6 +853,11 @@ class GroupCoordinator:
             )
         return obj_list

+    def all_gather_object(self, obj: Any) -> List[Any]:
+        objs = [None] * self.world_size
+        torch.distributed.all_gather_object(objs, obj, group=self.cpu_group)
+        return objs
+
     def send_object(self, obj: Any, dst: int) -> None:
         """Send the input object list to the destination rank."""
         """NOTE: `dst` is the local rank of the destination rank."""
...
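For context, a hedged usage sketch of the new all_gather_object method; the coordinator instance and the payload below are made up for illustration. Any picklable Python object works, and the gather runs over the CPU group, returning one entry per rank in rank order:

# `group` is assumed to be an already-initialized GroupCoordinator.
local_info = {"rank": group.rank, "num_tokens": 128}
all_info = group.all_gather_object(local_info)  # one entry per rank
assert len(all_info) == group.world_size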