Unverified Commit 097725bb authored by Lianmin Zheng, committed by GitHub

Clean up parallel_state.py (#11148)

parent 44b1fbe2
@@ -4,7 +4,7 @@
 # Adapted from
 # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-"""vLLM distributed state.
+"""Distributed state.
 It takes over the control of the distributed environment from PyTorch.
 The typical workflow is:
@@ -53,19 +53,26 @@ from sglang.srt.utils import (
 _is_npu = is_npu()
 _is_cpu = is_cpu()
+_supports_custom_op = supports_custom_op()

 IS_ONE_DEVICE_PER_PROCESS = get_bool_env_var("SGLANG_ONE_DEVICE_PER_PROCESS")

+TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
+
+# use int value instead of ReduceOp.SUM to support torch compile
+REDUCE_OP_SUM = int(torch.distributed.ReduceOp.SUM)
+

 @dataclass
 class GraphCaptureContext:
     stream: torch.cuda.Stream if not _is_npu else torch.npu.Stream


-TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
-
-# use int value instead of ReduceOp.SUM to support torch compile
-REDUCE_OP_SUM = int(torch.distributed.ReduceOp.SUM)
+@dataclass
+class P2PWork:
+    work: Optional[torch.distributed.Work]
+    payload: Optional[torch.Tensor]


 def _split_tensor_dict(
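The new `P2PWork` dataclass pairs a (possibly asynchronous) `torch.distributed` work handle with the tensor it transfers, so the payload cannot be garbage-collected before the transfer completes. A minimal sketch of how a caller might hold and drain such handles; the `wait_p2p_works` helper is illustrative and not part of this commit:

```python
from dataclasses import dataclass
from typing import List, Optional

import torch


@dataclass
class P2PWork:
    # Handle returned by torch.distributed.isend/irecv; None for synchronous ops.
    work: Optional[torch.distributed.Work]
    # Reference to the tensor being sent/received, kept alive until completion.
    payload: Optional[torch.Tensor]


def wait_p2p_works(p2p_works: List[P2PWork]) -> None:
    """Illustrative helper: block until all pending point-to-point transfers finish."""
    for p2p_work in p2p_works:
        if p2p_work.work is not None:
            p2p_work.work.wait()
```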
@@ -117,7 +124,7 @@ def _register_group(group: "GroupCoordinator") -> None:
     _groups[group.unique_name] = weakref.ref(group)


-if supports_custom_op():
+if _supports_custom_op:

     def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None:
         assert group_name in _groups, f"Group {group_name} is not found."
@@ -277,7 +284,7 @@ class GroupCoordinator:
         self.use_npu_communicator = use_npu_communicator
         self.use_message_queue_broadcaster = use_message_queue_broadcaster

-        # lazy import to avoid documentation build error
+        # Lazy import to avoid documentation build error
         from sglang.srt.distributed.device_communicators.custom_all_reduce import (
             CustomAllreduce,
         )
@@ -497,7 +504,7 @@ class GroupCoordinator:
             torch.distributed.all_reduce(input_, group=self.device_group)
             return input_

-        if not supports_custom_op():
+        if not _supports_custom_op:
            self._all_reduce_in_place(input_)
            return input_
@@ -523,23 +530,24 @@ class GroupCoordinator:
         outplace_all_reduce_method = None
         if (
-            self.qr_comm is not None
-            and not self.qr_comm.disabled
-            and self.qr_comm.should_quick_allreduce(input_)
-        ):
-            outplace_all_reduce_method = "qr"
-        elif (
             self.ca_comm is not None
             and not self.ca_comm.disabled
             and self.ca_comm.should_custom_ar(input_)
         ):
             outplace_all_reduce_method = "ca"
         elif (
+            self.qr_comm is not None
+            and not self.qr_comm.disabled
+            and self.qr_comm.should_quick_allreduce(input_)
+        ):
+            outplace_all_reduce_method = "qr"
+        elif (
             self.pymscclpp_comm is not None
             and not self.pymscclpp_comm.disabled
             and self.pymscclpp_comm.should_mscclpp_allreduce(input_)
         ):
             outplace_all_reduce_method = "pymscclpp"
         if outplace_all_reduce_method is not None:
             return torch.ops.sglang.outplace_all_reduce(
                 input_,
@@ -553,16 +561,16 @@ class GroupCoordinator:
     def _all_reduce_out_place(
         self, input_: torch.Tensor, outplace_all_reduce_method: str
     ) -> torch.Tensor:
-        qr_comm = self.qr_comm
         ca_comm = self.ca_comm
+        qr_comm = self.qr_comm
         pymscclpp_comm = self.pymscclpp_comm
         assert any([qr_comm, ca_comm, pymscclpp_comm])
-        if outplace_all_reduce_method == "qr":
-            assert not qr_comm.disabled
-            out = qr_comm.quick_all_reduce(input_)
-        elif outplace_all_reduce_method == "ca":
+        if outplace_all_reduce_method == "ca":
             assert not ca_comm.disabled
             out = ca_comm.custom_all_reduce(input_)
+        elif outplace_all_reduce_method == "qr":
+            assert not qr_comm.disabled
+            out = qr_comm.quick_all_reduce(input_)
         else:
             assert not pymscclpp_comm.disabled
             out = pymscclpp_comm.all_reduce(input_)
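After this reordering, the out-of-place path prefers the custom all-reduce (`"ca"`) over the quick all-reduce (`"qr"`), with pymscclpp as the last resort. A standalone sketch of that selection logic, using duck-typed communicator stand-ins rather than the real communicator classes (not part of the commit):

```python
from typing import Optional

import torch


def pick_outplace_all_reduce_method(
    ca_comm, qr_comm, pymscclpp_comm, input_: torch.Tensor
) -> Optional[str]:
    """Mirror the new precedence: custom AR, then quick AR, then pymscclpp."""
    if (
        ca_comm is not None
        and not ca_comm.disabled
        and ca_comm.should_custom_ar(input_)
    ):
        return "ca"
    if (
        qr_comm is not None
        and not qr_comm.disabled
        and qr_comm.should_quick_allreduce(input_)
    ):
        return "qr"
    if (
        pymscclpp_comm is not None
        and not pymscclpp_comm.disabled
        and pymscclpp_comm.should_mscclpp_allreduce(input_)
    ):
        return "pymscclpp"
    # None means: fall back to the in-place torch.distributed.all_reduce path.
    return None
```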
@@ -637,7 +645,7 @@ class GroupCoordinator:
         )

     def all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor):
-        if _is_npu or not supports_custom_op():
+        if _is_npu or not _supports_custom_op:
             self._all_gather_into_tensor(output, input)
         else:
             torch.ops.sglang.reg_all_gather_into_tensor(
@@ -697,12 +705,10 @@ class GroupCoordinator:
             )

         # All-gather.
-        if input_.is_cpu and is_shm_available(
-            input_.dtype, self.world_size, self.local_size
-        ):
-            return torch.ops.sgl_kernel.shm_allgather(input_, dim)
-
         if input_.is_cpu:
-            torch.distributed.all_gather_into_tensor(
-                output_tensor, input_, group=self.device_group
-            )
+            if is_shm_available(input_.dtype, self.world_size, self.local_size):
+                return torch.ops.sgl_kernel.shm_allgather(input_, dim)
+            else:
+                torch.distributed.all_gather_into_tensor(
+                    output_tensor, input_, group=self.device_group
+                )
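The CPU branch is now a single `if input_.is_cpu:` block: use the shared-memory all-gather kernel when it supports the dtype and topology, otherwise fall back to the regular collective. A hedged sketch of the same decision in isolation, assuming `is_shm_available` comes from `sglang.srt.utils` (as in this file's imports) and that a process group is already initialized:

```python
import torch


def cpu_all_gather_into(
    input_: torch.Tensor,
    output_tensor: torch.Tensor,
    dim: int,
    world_size: int,
    local_size: int,
    device_group,
):
    """Sketch only: prefer the sgl_kernel shared-memory all-gather on CPU."""
    from sglang.srt.utils import is_shm_available

    if is_shm_available(input_.dtype, world_size, local_size):
        # The shm kernel returns the already-concatenated result directly.
        return torch.ops.sgl_kernel.shm_allgather(input_, dim)
    # Otherwise gather into the preallocated output tensor via the collective.
    torch.distributed.all_gather_into_tensor(
        output_tensor, input_, group=device_group
    )
    return None
```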
@@ -861,45 +867,63 @@ class GroupCoordinator:
         torch.distributed.all_gather_object(objs, obj, group=self.cpu_group)
         return objs

-    def send_object(self, obj: Any, dst: int) -> None:
-        """Send the input object list to the destination rank."""
-        """NOTE: `dst` is the local rank of the destination rank."""
+    def send_object(
+        self,
+        obj: Any,
+        dst: int,
+        async_send: bool = False,
+    ) -> List[P2PWork]:
+        """
+        Send the input object list to the destination rank.
+        This function uses the CPU group for all communications.
+
+        TODO: If you want to use GPU communication, please add a new argument (e.g., data_group, group),
+        use other functions (e.g., send), or implement a new function (e.g., send_object_device).
+
+        NOTE: `dst` is the local rank of the destination rank.
+        """
         assert dst < self.world_size, f"Invalid dst rank ({dst})"
         assert dst != self.rank_in_group, (
             "Invalid destination rank. Destination rank is the same "
             "as the current rank."
         )

+        send_func = torch.distributed.isend if async_send else torch.distributed.send
+
         # Serialize object to tensor and get the size as well
-        object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8).cuda(
-            device=torch.cuda.current_device()
-        )
+        object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)
         size_tensor = torch.tensor(
-            [object_tensor.numel()],
-            dtype=torch.long,
-            device="cpu",
+            [object_tensor.numel()], dtype=torch.long, device="cpu"
         )

         # Send object size
-        torch.distributed.send(size_tensor, dst=self.ranks[dst], group=self.cpu_group)
+        p2p_work = []
+        size_work = send_func(
+            size_tensor,
+            self.ranks[dst],
+            group=self.cpu_group,
+        )
+        if async_send:
+            p2p_work.append(P2PWork(size_work, size_tensor))

-        # Send object
-        torch.distributed.send(
+        object_work = send_func(
             object_tensor,
-            dst=self.ranks[dst],
-            group=self.device_group,
+            self.ranks[dst],
+            group=self.cpu_group,
         )
+        if async_send:
+            p2p_work.append(P2PWork(object_work, object_tensor))

-        return None
+        return p2p_work

-    def recv_object(self, src: int) -> Any:
+    def recv_object(
+        self,
+        src: int,
+    ) -> Any:
         """Receive the input object list from the source rank."""
         """NOTE: `src` is the local rank of the source rank."""
         assert src < self.world_size, f"Invalid src rank ({src})"
         assert (
             src != self.rank_in_group
         ), "Invalid source rank. Source rank is the same as the current rank."
@@ -907,27 +931,25 @@ class GroupCoordinator:
         size_tensor = torch.empty(1, dtype=torch.long, device="cpu")

         # Receive object size
-        rank_size = torch.distributed.recv(
+        # We have to use irecv here to make it work for both isend and send.
+        work = torch.distributed.irecv(
             size_tensor, src=self.ranks[src], group=self.cpu_group
         )
+        work.wait()

         # Tensor to receive serialized objects into.
-        object_tensor = torch.empty(  # type: ignore[call-overload]
+        object_tensor: Any = torch.empty(  # type: ignore[call-overload]
             size_tensor.item(),  # type: ignore[arg-type]
             dtype=torch.uint8,
-            device=torch.cuda.current_device(),
+            device="cpu",
         )

-        rank_object = torch.distributed.recv(
-            object_tensor, src=self.ranks[src], group=self.device_group
+        work = torch.distributed.irecv(
+            object_tensor, src=self.ranks[src], group=self.cpu_group
         )
+        work.wait()

-        assert (
-            rank_object == rank_size
-        ), "Received object sender rank does not match the size sender rank."
-
-        obj = pickle.loads(object_tensor.cpu().numpy())
+        obj = pickle.loads(object_tensor.numpy())

         return obj

     def broadcast_tensor_dict(
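The reworked `send_object`/`recv_object` keep both the size and the payload on the CPU group, and the receiver always uses `irecv` + `wait()` so it pairs with either a synchronous or an asynchronous sender. A self-contained sketch of the same handshake outside sglang, using plain `torch.distributed` with the gloo backend and two spawned processes (the helper names here are made up for illustration):

```python
import os
import pickle

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def _send_object(obj, dst: int, async_send: bool):
    """Serialize obj, then send its size and payload over the CPU (gloo) group."""
    send_func = dist.isend if async_send else dist.send
    payload = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)
    size = torch.tensor([payload.numel()], dtype=torch.long)
    works = [send_func(size, dst), send_func(payload, dst)]
    # Keep each tensor alive next to its work handle until the send completes.
    return list(zip(works, (size, payload))) if async_send else []


def _recv_object(src: int):
    """irecv + wait works whether the peer used send or isend."""
    size = torch.empty(1, dtype=torch.long)
    dist.irecv(size, src=src).wait()
    payload = torch.empty(size.item(), dtype=torch.uint8)
    dist.irecv(payload, src=src).wait()
    return pickle.loads(payload.numpy())


def _worker(rank: int, world_size: int):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29531"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    if rank == 0:
        pending = _send_object({"step": 1, "note": "hello"}, dst=1, async_send=True)
        for work, _tensor in pending:  # drain the async sends before exiting
            work.wait()
    else:
        print("rank 1 received:", _recv_object(src=0))
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_worker, args=(2,), nprocs=2)
```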
@@ -1017,12 +1039,13 @@ class GroupCoordinator:
         tensor_dict: Dict[str, Union[torch.Tensor, Any]],
         dst: Optional[int] = None,
         all_gather_group: Optional["GroupCoordinator"] = None,
-    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
+        async_send: bool = False,
+    ) -> Optional[List[P2PWork]]:
         """Send the input tensor dictionary.
         NOTE: `dst` is the local rank of the source rank.
         """
         # Bypass the function if we are using only 1 GPU.
-        if not torch.distributed.is_initialized() or self.world_size == 1:
+        if self.world_size == 1:
             return tensor_dict

         all_gather_size = 1 if all_gather_group is None else all_gather_group.world_size
@@ -1047,7 +1070,10 @@ class GroupCoordinator:
         # 1. Superior D2D transfer bandwidth
         # 2. Ability to overlap send and recv operations
         # Thus the net performance gain justifies this approach.
-        self.send_object(metadata_list, dst=dst)
+
+        send_func = torch.distributed.isend if async_send else torch.distributed.send
+        p2p_works = self.send_object(metadata_list, dst=dst, async_send=async_send)
+
         for tensor in tensor_list:
             if tensor.numel() == 0:
                 # Skip sending empty tensors.
@@ -1057,15 +1083,10 @@ class GroupCoordinator:
             if all_gather_group is not None and tensor.numel() % all_gather_size == 0:
                 tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]

-            if tensor.is_cpu:
-                # use metadata_group for CPU tensors
-                torch.distributed.send(
-                    tensor, dst=self.ranks[dst], group=metadata_group
-                )
-            else:
-                # use group for GPU tensors
-                torch.distributed.send(tensor, dst=self.ranks[dst], group=group)
-        return None
+            comm_group = metadata_group if tensor.is_cpu else group
+            work = send_func(tensor, self.ranks[dst], group=comm_group)
+            p2p_works.append(P2PWork(work, tensor))
+        return p2p_works

     def recv_tensor_dict(
         self,
@@ -1111,17 +1132,15 @@ class GroupCoordinator:
                 orig_shape = tensor.shape
                 tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]

-            if tensor.is_cpu:
-                # use metadata_group for CPU tensors
-                torch.distributed.recv(
-                    tensor, src=self.ranks[src], group=metadata_group
-                )
-            else:
-                # use group for GPU tensors
-                torch.distributed.recv(tensor, src=self.ranks[src], group=group)
+            # We have to use irecv here to make it work for both isend and send.
+            comm_group = metadata_group if tensor.is_cpu else group
+            work = torch.distributed.irecv(
+                tensor, src=self.ranks[src], group=comm_group
+            )
+            work.wait()
+
             if use_all_gather:
-                # do the allgather
-                tensor = all_gather_group.all_gather(tensor, dim=0)  # type: ignore
+                tensor = all_gather_group.all_gather(tensor, dim=0)
                 tensor = tensor.reshape(orig_shape)
             tensor_dict[key] = tensor
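With `async_send=True`, `send_tensor_dict` now returns the list of `P2PWork` handles instead of `None`; the caller can overlap other work with the in-flight `isend`s and then wait, while `P2PWork` keeps the payload tensors alive. A hedged caller-side sketch, assuming a multi-rank group obtained via `get_tp_group()` from an already initialized sglang distributed environment (not part of this diff):

```python
import torch

from sglang.srt.distributed import get_tp_group


def exchange_tensor_dict(tensor_dict, peer_rank: int, is_sender: bool):
    """Illustrative only; ranks are local ranks within the group, world_size > 1."""
    group = get_tp_group()
    if is_sender:
        pending = group.send_tensor_dict(tensor_dict, dst=peer_rank, async_send=True)
        # ... overlap unrelated work here while the isend calls are in flight ...
        for p2p_work in pending:
            if p2p_work.work is not None:
                p2p_work.work.wait()
        return None
    # The receive path is unchanged for callers: recv_tensor_dict already uses
    # irecv + wait internally, so it pairs with both sync and async senders.
    return group.recv_tensor_dict(src=peer_rank)
```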