Unverified commit 097725bb, authored by Lianmin Zheng and committed by GitHub

Clean up parallel_state.py (#11148)

parent 44b1fbe2
@@ -4,7 +4,7 @@
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""vLLM distributed state.
"""Distributed state.
It takes over the control of the distributed environment from PyTorch.
The typical workflow is:
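The collapsed portion of this docstring lists the typical setup steps. As a rough orientation, here is a minimal sketch of that workflow using this module's entry points (names follow the upstream vLLM `parallel_state` this file adapts; the exact arguments are illustrative assumptions):

```python
# Illustrative sketch only; argument names and values are assumptions.
from sglang.srt.distributed import (
    init_distributed_environment,   # sets up torch.distributed (rank, world size, backend)
    initialize_model_parallel,      # splits the world into tensor-/pipeline-parallel groups
)

# 1) Initialize the global distributed state for this process.
init_distributed_environment(
    world_size=8,
    rank=0,
    local_rank=0,
    distributed_init_method="tcp://127.0.0.1:29500",
)

# 2) Build the model-parallel process groups used by GroupCoordinator.
initialize_model_parallel(tensor_model_parallel_size=8)
```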
@@ -53,19 +53,26 @@ from sglang.srt.utils import (
_is_npu = is_npu()
_is_cpu = is_cpu()
_supports_custom_op = supports_custom_op()
IS_ONE_DEVICE_PER_PROCESS = get_bool_env_var("SGLANG_ONE_DEVICE_PER_PROCESS")
TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
# use int value instead of ReduceOp.SUM to support torch compile
REDUCE_OP_SUM = int(torch.distributed.ReduceOp.SUM)
@dataclass
class GraphCaptureContext:
stream: torch.cuda.Stream if not _is_npu else torch.npu.Stream
TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
# use int value instead of ReduceOp.SUM to support torch compile
REDUCE_OP_SUM = int(torch.distributed.ReduceOp.SUM)
@dataclass
class P2PWork:
work: Optional[torch.distributed.Work]
payload: Optional[torch.Tensor]
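`P2PWork` pairs an asynchronous `torch.distributed.Work` handle with the tensor it transports, so the buffer stays alive until the send completes. A minimal sketch of the intended pattern (the helper name is illustrative, not part of this module):

```python
import torch
import torch.distributed as dist

def isend_with_payload(tensor: torch.Tensor, dst: int, group) -> P2PWork:
    """Illustrative helper: start an async send and keep its buffer alive."""
    work = dist.isend(tensor, dst=dst, group=group)  # returns a torch.distributed.Work
    # Returning the tensor alongside the handle prevents it from being freed
    # (and its memory reused) before the transfer actually finishes.
    return P2PWork(work=work, payload=tensor)

# The caller drains outstanding sends once overlap with compute is done:
#   for p2p in pending:
#       if p2p.work is not None:
#           p2p.work.wait()
```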
def _split_tensor_dict(
@@ -117,7 +124,7 @@ def _register_group(group: "GroupCoordinator") -> None:
_groups[group.unique_name] = weakref.ref(group)
if supports_custom_op():
if _supports_custom_op:
def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None:
assert group_name in _groups, f"Group {group_name} is not found."
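The ops registered under this guard identify the group by its string name rather than by object, so only tensors and plain strings cross the custom-op boundary; the name is resolved through the module-level `_groups` weakref registry. A rough sketch of the pattern (simplified from the surrounding code):

```python
import weakref

_groups = {}  # unique_name -> weakref.ref(GroupCoordinator)

def _register_group(group) -> None:
    # A weak reference keeps the registry from pinning destroyed coordinators.
    _groups[group.unique_name] = weakref.ref(group)

def _resolve_group(group_name: str) -> "GroupCoordinator":
    assert group_name in _groups, f"Group {group_name} is not found."
    group = _groups[group_name]()
    assert group is not None, f"Group {group_name} has been destroyed."
    return group

# Inside the registered op, this lookup replaces passing the Python object:
#   def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None:
#       _resolve_group(group_name)._all_reduce_in_place(tensor)
```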
@@ -277,7 +284,7 @@ class GroupCoordinator:
self.use_npu_communicator = use_npu_communicator
self.use_message_queue_broadcaster = use_message_queue_broadcaster
# lazy import to avoid documentation build error
# Lazy import to avoid documentation build error
from sglang.srt.distributed.device_communicators.custom_all_reduce import (
CustomAllreduce,
)
@@ -497,7 +504,7 @@ class GroupCoordinator:
torch.distributed.all_reduce(input_, group=self.device_group)
return input_
if not supports_custom_op():
if not _supports_custom_op:
self._all_reduce_in_place(input_)
return input_
@@ -523,23 +530,24 @@ class GroupCoordinator:
outplace_all_reduce_method = None
if (
self.qr_comm is not None
and not self.qr_comm.disabled
and self.qr_comm.should_quick_allreduce(input_)
):
outplace_all_reduce_method = "qr"
elif (
self.ca_comm is not None
and not self.ca_comm.disabled
and self.ca_comm.should_custom_ar(input_)
):
outplace_all_reduce_method = "ca"
elif (
self.qr_comm is not None
and not self.qr_comm.disabled
and self.qr_comm.should_quick_allreduce(input_)
):
outplace_all_reduce_method = "qr"
elif (
self.pymscclpp_comm is not None
and not self.pymscclpp_comm.disabled
and self.pymscclpp_comm.should_mscclpp_allreduce(input_)
):
outplace_all_reduce_method = "pymscclpp"
if outplace_all_reduce_method is not None:
return torch.ops.sglang.outplace_all_reduce(
input_,
@@ -553,16 +561,16 @@ class GroupCoordinator:
def _all_reduce_out_place(
self, input_: torch.Tensor, outplace_all_reduce_method: str
) -> torch.Tensor:
qr_comm = self.qr_comm
ca_comm = self.ca_comm
qr_comm = self.qr_comm
pymscclpp_comm = self.pymscclpp_comm
assert any([qr_comm, ca_comm, pymscclpp_comm])
if outplace_all_reduce_method == "qr":
assert not qr_comm.disabled
out = qr_comm.quick_all_reduce(input_)
elif outplace_all_reduce_method == "ca":
if outplace_all_reduce_method == "ca":
assert not ca_comm.disabled
out = ca_comm.custom_all_reduce(input_)
elif outplace_all_reduce_method == "qr":
assert not qr_comm.disabled
out = qr_comm.quick_all_reduce(input_)
else:
assert not pymscclpp_comm.disabled
out = pymscclpp_comm.all_reduce(input_)
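The two hunks above reorder the out-of-place backends so quick allreduce ("qr") is preferred over custom allreduce ("ca"), with pymscclpp last. A condensed sketch of the resulting dispatch (illustrative; the real path also short-circuits single-rank groups, handles CPU tensors, and routes through `torch.ops.sglang.outplace_all_reduce`):

```python
import torch

def all_reduce_sketch(group: "GroupCoordinator", input_: torch.Tensor) -> torch.Tensor:
    if not _supports_custom_op:
        group._all_reduce_in_place(input_)   # plain torch.distributed.all_reduce
        return input_
    # New preference order for the out-of-place fast paths: qr -> ca -> pymscclpp.
    for name, comm, check in (
        ("qr", group.qr_comm, "should_quick_allreduce"),
        ("ca", group.ca_comm, "should_custom_ar"),
        ("pymscclpp", group.pymscclpp_comm, "should_mscclpp_allreduce"),
    ):
        if comm is not None and not comm.disabled and getattr(comm, check)(input_):
            return group._all_reduce_out_place(input_, name)
    group._all_reduce_in_place(input_)       # no fast path applies: in-place fallback
    return input_
```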
@@ -637,7 +645,7 @@ class GroupCoordinator:
)
def all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor):
if _is_npu or not supports_custom_op():
if _is_npu or not _supports_custom_op:
self._all_gather_into_tensor(output, input)
else:
torch.ops.sglang.reg_all_gather_into_tensor(
@@ -697,15 +705,13 @@ class GroupCoordinator:
)
# All-gather.
if input_.is_cpu and is_shm_available(
input_.dtype, self.world_size, self.local_size
):
return torch.ops.sgl_kernel.shm_allgather(input_, dim)
if input_.is_cpu:
torch.distributed.all_gather_into_tensor(
output_tensor, input_, group=self.device_group
)
if is_shm_available(input_.dtype, self.world_size, self.local_size):
return torch.ops.sgl_kernel.shm_allgather(input_, dim)
else:
torch.distributed.all_gather_into_tensor(
output_tensor, input_, group=self.device_group
)
else:
self.all_gather_into_tensor(output_tensor, input_)
@@ -861,45 +867,63 @@ class GroupCoordinator:
torch.distributed.all_gather_object(objs, obj, group=self.cpu_group)
return objs
def send_object(self, obj: Any, dst: int) -> None:
"""Send the input object list to the destination rank."""
"""NOTE: `dst` is the local rank of the destination rank."""
def send_object(
self,
obj: Any,
dst: int,
async_send: bool = False,
) -> List[P2PWork]:
"""
Send the input object list to the destination rank.
This function uses the CPU group for all communications.
assert dst < self.world_size, f"Invalid dst rank ({dst})"
TODO: If you want to use GPU communication, please add a new argument (e.g., data_group, group),
use other functions (e.g., send), or implement a new function (e.g., send_object_device).
NOTE: `dst` is the local rank of the destination rank.
"""
assert dst < self.world_size, f"Invalid dst rank ({dst})"
assert dst != self.rank_in_group, (
"Invalid destination rank. Destination rank is the same "
"as the current rank."
)
send_func = torch.distributed.isend if async_send else torch.distributed.send
# Serialize object to tensor and get the size as well
object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8).cuda(
device=torch.cuda.current_device()
)
object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)
size_tensor = torch.tensor(
[object_tensor.numel()],
dtype=torch.long,
device="cpu",
[object_tensor.numel()], dtype=torch.long, device="cpu"
)
# Send object size
torch.distributed.send(size_tensor, dst=self.ranks[dst], group=self.cpu_group)
p2p_work = []
size_work = send_func(
size_tensor,
self.ranks[dst],
group=self.cpu_group,
)
if async_send:
p2p_work.append(P2PWork(size_work, size_tensor))
# Send object
torch.distributed.send(
object_work = send_func(
object_tensor,
dst=self.ranks[dst],
group=self.device_group,
self.ranks[dst],
group=self.cpu_group,
)
if async_send:
p2p_work.append(P2PWork(object_work, object_tensor))
return None
return p2p_work
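`send_object` now stays entirely on the CPU group and keeps the size-then-payload handshake: the receiver has to learn the buffer size before it can post a receive for the pickled bytes. A bare-bones sketch of that handshake with plain `torch.distributed` (ranks and group are illustrative; this is not the module's code):

```python
import pickle
import torch
import torch.distributed as dist

def send_pickled(obj, dst, cpu_group):
    payload = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)
    size = torch.tensor([payload.numel()], dtype=torch.long)
    dist.send(size, dst, group=cpu_group)      # 1) tell the peer how much to expect
    dist.send(payload, dst, group=cpu_group)   # 2) ship the pickled bytes

def recv_pickled(src, cpu_group):
    size = torch.empty(1, dtype=torch.long)
    dist.recv(size, src, group=cpu_group)
    buf = torch.empty(int(size.item()), dtype=torch.uint8)
    dist.recv(buf, src, group=cpu_group)
    return pickle.loads(buf.numpy())
```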
def recv_object(self, src: int) -> Any:
def recv_object(
self,
src: int,
) -> Any:
"""Receive the input object list from the source rank."""
"""NOTE: `src` is the local rank of the source rank."""
assert src < self.world_size, f"Invalid src rank ({src})"
assert (
src != self.rank_in_group
), "Invalid source rank. Source rank is the same as the current rank."
@@ -907,27 +931,25 @@ class GroupCoordinator:
size_tensor = torch.empty(1, dtype=torch.long, device="cpu")
# Receive object size
rank_size = torch.distributed.recv(
# We have to use irecv here to make it work for both isend and send.
work = torch.distributed.irecv(
size_tensor, src=self.ranks[src], group=self.cpu_group
)
work.wait()
# Tensor to receive serialized objects into.
object_tensor = torch.empty( # type: ignore[call-overload]
object_tensor: Any = torch.empty( # type: ignore[call-overload]
size_tensor.item(), # type: ignore[arg-type]
dtype=torch.uint8,
device=torch.cuda.current_device(),
device="cpu",
)
rank_object = torch.distributed.recv(
object_tensor, src=self.ranks[src], group=self.device_group
work = torch.distributed.irecv(
object_tensor, src=self.ranks[src], group=self.cpu_group
)
work.wait()
assert (
rank_object == rank_size
), "Received object sender rank does not match the size sender rank."
obj = pickle.loads(object_tensor.cpu().numpy())
obj = pickle.loads(object_tensor.numpy())
return obj
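Because `recv_object` always posts `irecv` and waits, it pairs correctly with either the synchronous or the asynchronous sender. A hedged usage sketch across two ranks of the same group (the `group` object and rank numbers are illustrative):

```python
# Illustrative pairing of send_object(async_send=True) with recv_object.
metadata = {"seq_lens": [5, 7], "step": 42}

if group.rank_in_group == 0:
    # Start both CPU-group sends asynchronously; keep the P2PWork list alive.
    pending = group.send_object(metadata, dst=1, async_send=True)
    # ... overlap with other work ...
    for p2p in pending:
        if p2p.work is not None:
            p2p.work.wait()
elif group.rank_in_group == 1:
    received = group.recv_object(src=0)   # blocking: irecv + wait internally
    assert received == metadata
```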
def broadcast_tensor_dict(
@@ -1017,12 +1039,13 @@ class GroupCoordinator:
tensor_dict: Dict[str, Union[torch.Tensor, Any]],
dst: Optional[int] = None,
all_gather_group: Optional["GroupCoordinator"] = None,
) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
async_send: bool = False,
) -> Optional[List[P2PWork]]:
"""Send the input tensor dictionary.
NOTE: `dst` is the local rank of the destination rank.
"""
# Bypass the function if we are using only 1 GPU.
if not torch.distributed.is_initialized() or self.world_size == 1:
if self.world_size == 1:
return tensor_dict
all_gather_size = 1 if all_gather_group is None else all_gather_group.world_size
@@ -1047,7 +1070,10 @@ class GroupCoordinator:
# 1. Superior D2D transfer bandwidth
# 2. Ability to overlap send and recv operations
# Thus the net performance gain justifies this approach.
self.send_object(metadata_list, dst=dst)
send_func = torch.distributed.isend if async_send else torch.distributed.send
p2p_works = self.send_object(metadata_list, dst=dst, async_send=async_send)
for tensor in tensor_list:
if tensor.numel() == 0:
# Skip sending empty tensors.
@@ -1057,15 +1083,10 @@ class GroupCoordinator:
if all_gather_group is not None and tensor.numel() % all_gather_size == 0:
tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]
if tensor.is_cpu:
# use metadata_group for CPU tensors
torch.distributed.send(
tensor, dst=self.ranks[dst], group=metadata_group
)
else:
# use group for GPU tensors
torch.distributed.send(tensor, dst=self.ranks[dst], group=group)
return None
comm_group = metadata_group if tensor.is_cpu else group
work = send_func(tensor, self.ranks[dst], group=comm_group)
p2p_works.append(P2PWork(work, tensor))
return p2p_works
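With `async_send=True`, `send_tensor_dict` returns the P2PWork handles (one for the metadata message plus one per non-empty tensor) instead of blocking, so a pipeline-parallel sender can overlap the transfer with compute. A hedged usage sketch (`pp_group`, the tensors, and the destination rank are placeholders):

```python
out = {"hidden_states": hidden_states, "residual": residual}

# `dst` is the rank local to pp_group, as noted in the docstring.
pending = pp_group.send_tensor_dict(out, dst=next_rank_in_group, async_send=True)

# ... run other compute here while the transfer is in flight ...

# Drain the sends before any tensor in `out` may be reused or freed.
for p2p in pending or []:
    if p2p.work is not None:
        p2p.work.wait()
```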
def recv_tensor_dict(
self,
@@ -1111,17 +1132,15 @@ class GroupCoordinator:
orig_shape = tensor.shape
tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]
if tensor.is_cpu:
# use metadata_group for CPU tensors
torch.distributed.recv(
tensor, src=self.ranks[src], group=metadata_group
)
else:
# use group for GPU tensors
torch.distributed.recv(tensor, src=self.ranks[src], group=group)
# We have to use irecv here to make it work for both isend and send.
comm_group = metadata_group if tensor.is_cpu else group
work = torch.distributed.irecv(
tensor, src=self.ranks[src], group=comm_group
)
work.wait()
if use_all_gather:
# do the allgather
tensor = all_gather_group.all_gather(tensor, dim=0) # type: ignore
tensor = all_gather_group.all_gather(tensor, dim=0)
tensor = tensor.reshape(orig_shape)
tensor_dict[key] = tensor
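When an `all_gather_group` is supplied, only this rank's shard of each evenly divisible tensor travels point-to-point; the receiver all-gathers the shards and reshapes back to the original shape. A small single-process demo of why the reshape round trip is lossless (shapes are illustrative):

```python
import torch

all_gather_size, all_gather_rank = 4, 2
full = torch.arange(4096, dtype=torch.float32).reshape(64, 64)
orig_shape = full.shape

# Sender: transmit only this rank's slice, i.e. 1/4 of the elements.
shard = full.reshape(all_gather_size, -1)[all_gather_rank]
assert shard.numel() == full.numel() // all_gather_size

# Receiver: an all-gather along dim 0 lines the slices up back-to-back,
# so a plain reshape restores the original layout.
gathered = torch.cat(
    [full.reshape(all_gather_size, -1)[r] for r in range(all_gather_size)], dim=0
)  # stands in for all_gather_group.all_gather(shard, dim=0)
rebuilt = gathered.reshape(orig_shape)
assert torch.equal(rebuilt, full)
```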