Unverified commit c4e81e64, authored by ykcombat and committed by GitHub

[Feature] Use current greenctx stream to communicate in PD-Multiplexing. (#11594)

parent c726d44c
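Why the change: with PD-multiplexing, prefill and decode work is issued on separate green-context streams, so NCCL communication (and CUDA-graph capture) has to follow whichever stream is current for the caller instead of the private stream each communicator created at init time. A minimal sketch of that idea, using plain torch.cuda.Stream() objects as stand-ins for the green-context streams (the real streams are derived from CUDA green contexts and are not part of this sketch):

import torch

prefill_stream = torch.cuda.Stream()  # stand-in for the prefill greenctx stream
decode_stream = torch.cuda.Stream()   # stand-in for the decode greenctx stream

with torch.cuda.stream(prefill_stream):
    # With use_current_stream=True (added below), the communicator resolves
    # its NCCL stream via torch.cuda.current_stream(), i.e. prefill_stream
    # here, rather than the stream it created for itself at construction.
    assert torch.cuda.current_stream().cuda_stream == prefill_stream.cuda_stream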
@@ -30,6 +30,7 @@ class PyNcclCommunicator:
group: Union[ProcessGroup, StatelessProcessGroup],
device: Union[int, str, torch.device],
library_path: Optional[str] = None,
+use_current_stream: bool = False,
):
"""
Args:
@@ -74,6 +75,7 @@ class PyNcclCommunicator:
self.available = True
self.disabled = False
+self.use_current_stream = use_current_stream
self.nccl_version = self.nccl.ncclGetRawVersion()
if self.rank == 0:
@@ -123,6 +125,21 @@ class PyNcclCommunicator:
# when we are using CUDA graph.
self.disabled = True
+def _resolve_stream(self, stream: Optional[torch.cuda.Stream]):
+    """Return the stream to use for NCCL calls.
+
+    Behavior mirrors the previous inline logic:
+    - if an explicit stream is provided, return it
+    - if stream is None and self.use_current_stream is True, return
+      torch.cuda.current_stream()
+    - otherwise return the communicator's default stream (self.stream)
+    """
+    if stream is not None:
+        return stream
+    if self.use_current_stream:
+        return torch.cuda.current_stream()
+    return self.stream
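For reference, the resolution order of the new helper restated as a stand-alone, runnable snippet; comm is just a hypothetical holder for the two attributes the communicator consults, and nothing here touches NCCL:

import torch
from types import SimpleNamespace

def resolve(comm, stream):
    if stream is not None:        # 1) an explicit stream always wins
        return stream
    if comm.use_current_stream:   # 2) otherwise follow the caller's current stream
        return torch.cuda.current_stream()
    return comm.stream            # 3) legacy default: the communicator's own stream

comm = SimpleNamespace(use_current_stream=True, stream=torch.cuda.Stream())
s = torch.cuda.Stream()
assert resolve(comm, s) is s
with torch.cuda.stream(s):
    assert resolve(comm, None).cuda_stream == s.cuda_stream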
def all_reduce(
self, tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, stream=None
):
@@ -135,8 +152,7 @@ class PyNcclCommunicator:
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}"
)
-if stream is None:
-    stream = self.stream
+stream = self._resolve_stream(stream)
self.nccl.ncclAllReduce(
buffer_type(tensor.data_ptr()),
buffer_type(tensor.data_ptr()),
@@ -163,8 +179,7 @@ class PyNcclCommunicator:
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {input_tensor.device}"
)
-if stream is None:
-    stream = self.stream
+stream = self._resolve_stream(stream)
if sizes is not None:
split_offset = 0
@@ -210,8 +225,7 @@ class PyNcclCommunicator:
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {input_tensor.device}"
)
-if stream is None:
-    stream = self.stream
+stream = self._resolve_stream(stream)
if sizes is not None:
split_offset = 0
@@ -249,8 +263,7 @@ class PyNcclCommunicator:
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}"
)
-if stream is None:
-    stream = self.stream
+stream = self._resolve_stream(stream)
self.nccl.ncclSend(
buffer_type(tensor.data_ptr()),
tensor.numel(),
@@ -267,8 +280,7 @@ class PyNcclCommunicator:
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}"
)
-if stream is None:
-    stream = self.stream
+stream = self._resolve_stream(stream)
self.nccl.ncclRecv(
buffer_type(tensor.data_ptr()),
tensor.numel(),
@@ -285,8 +297,8 @@ class PyNcclCommunicator:
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}"
)
-if stream is None:
-    stream = self.stream
+stream = self._resolve_stream(stream)
if src == self.rank:
sendbuff = buffer_type(tensor.data_ptr())
# NCCL requires the sender also to have a receive buffer
......
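A hedged usage sketch (not part of the PR) showing the flag from the caller's side; it assumes the distributed environment is already initialized and runs on every rank, and group and prefill_stream are placeholders:

import torch

comm = PyNcclCommunicator(
    group=group,                  # an existing ProcessGroup / StatelessProcessGroup
    device=torch.cuda.current_device(),
    use_current_stream=True,      # flag introduced above
)

x = torch.ones(16, device="cuda")
with torch.cuda.stream(prefill_stream):
    # stream=None plus use_current_stream=True makes _resolve_stream pick
    # torch.cuda.current_stream(), so the all-reduce is ordered after the
    # prefill work already queued on prefill_stream.
    comm.all_reduce(x)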
@@ -239,6 +239,7 @@ class GroupCoordinator:
use_npu_communicator: bool,
use_message_queue_broadcaster: bool = False,
group_name: Optional[str] = None,
+pynccl_use_current_stream: bool = False,
torch_compile: Optional[bool] = None,
gloo_timeout: timedelta = timedelta(seconds=120 * 60),
):
@@ -289,6 +290,7 @@ class GroupCoordinator:
# Import communicators
self.use_pynccl = use_pynccl
+self.pynccl_use_current_stream = pynccl_use_current_stream
self.use_pymscclpp = use_pymscclpp
self.use_custom_allreduce = use_custom_allreduce
self.use_torch_symm_mem = use_torch_symm_mem
@@ -322,6 +324,7 @@ class GroupCoordinator:
self.pynccl_comm = PyNcclCommunicator(
group=self.cpu_group,
device=self.device,
+use_current_stream=pynccl_use_current_stream,
)
self.pymscclpp_comm: Optional[PyMscclppCommunicator] = None
@@ -449,10 +452,13 @@ class GroupCoordinator:
@contextmanager
def graph_capture(
-    self, graph_capture_context: Optional[GraphCaptureContext] = None
+    self,
+    graph_capture_context: Optional[GraphCaptureContext] = None,
+    stream: Optional[torch.cuda.Stream] = None,
):
if graph_capture_context is None:
-    stream = self.device_module.Stream()
+    if stream is None:
+        stream = self.device_module.Stream()
graph_capture_context = GraphCaptureContext(stream)
else:
stream = graph_capture_context.stream
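Based on how it is constructed and read in this hunk, GraphCaptureContext is essentially a small holder for the capture stream; a simplified sketch (not the actual definition):

from dataclasses import dataclass
import torch

@dataclass
class GraphCaptureContext:
    stream: torch.cuda.Stream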
@@ -1278,6 +1284,7 @@ def init_model_parallel_group(
use_message_queue_broadcaster: bool = False,
group_name: Optional[str] = None,
use_mscclpp_allreduce: Optional[bool] = None,
+pynccl_use_current_stream: bool = True,
use_symm_mem_allreduce: Optional[bool] = None,
torch_compile: Optional[bool] = None,
) -> GroupCoordinator:
@@ -1300,6 +1307,7 @@ def init_model_parallel_group(
use_npu_communicator=True,
use_message_queue_broadcaster=use_message_queue_broadcaster,
group_name=group_name,
+pynccl_use_current_stream=pynccl_use_current_stream,
torch_compile=torch_compile,
)
@@ -1357,7 +1365,7 @@ get_pipeline_model_parallel_group = get_pp_group
@contextmanager
-def graph_capture():
+def graph_capture(stream: Optional[torch.cuda.Stream] = None):
"""
`graph_capture` is a context manager which should surround the code that
is capturing the CUDA graph. Its main purpose is to ensure that the
@@ -1371,9 +1379,9 @@ def graph_capture():
in order to explicitly distinguish the kernels to capture
from other kernels possibly launched on background in the default stream.
"""
-with get_tp_group().graph_capture() as context, get_pp_group().graph_capture(
-    context
-):
+with get_tp_group().graph_capture(
+    stream=stream
+) as context, get_pp_group().graph_capture(context):
yield context
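A hedged usage sketch of the extended entry point: capturing a CUDA graph on a caller-chosen stream (under PD-multiplexing this would be a green-context stream); model and static_input are placeholders, not part of the PR:

import torch

capture_stream = torch.cuda.Stream()
with graph_capture(stream=capture_stream) as capture_context:
    g = torch.cuda.CUDAGraph()
    # Kernels (including NCCL collectives routed through the groups above)
    # launched on the capture stream inside this block are recorded into g.
    with torch.cuda.graph(g, stream=capture_context.stream):
        static_output = model(static_input)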
@@ -1527,6 +1535,7 @@ def initialize_model_parallel(
"SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true"
),
group_name="tp",
+pynccl_use_current_stream=duplicate_tp_group,
torch_compile=torch_compile,
)
@@ -1543,10 +1552,12 @@ def initialize_model_parallel(
"SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true"
),
group_name="pdmux_prefill_tp",
+pynccl_use_current_stream=True,
torch_compile=torch_compile,
)
-_TP.pynccl_comm.disabled = False
-_PDMUX_PREFILL_TP_GROUP.pynccl_comm.disabled = False
+if _TP.pynccl_comm:
+    _TP.pynccl_comm.disabled = False
+    _PDMUX_PREFILL_TP_GROUP.pynccl_comm.disabled = False
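The guard above matters because a GroupCoordinator may end up without a PyNcclCommunicator (pynccl_comm stays None), in which case the old unconditional attribute write would raise; a small restatement of the defensive pattern, with the single-rank case as an assumed example:

for group in (_TP, _PDMUX_PREFILL_TP_GROUP):
    # e.g. a 1-GPU run may skip creating the communicator (assumption,
    # not shown in this diff); only touch it when it actually exists.
    if group is not None and group.pynccl_comm is not None:
        group.pynccl_comm.disabled = False  # keep the pynccl path enabled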
moe_ep_size = expert_model_parallel_size
moe_tp_size = tensor_model_parallel_size // moe_ep_size
@@ -1737,6 +1748,11 @@ def destroy_model_parallel():
_PP.destroy()
_PP = None
+global _PDMUX_PREFILL_TP_GROUP
+if _PDMUX_PREFILL_TP_GROUP:  # type: ignore[union-attr]
+    _PDMUX_PREFILL_TP_GROUP.destroy()
+    _PDMUX_PREFILL_TP_GROUP = None
def destroy_distributed_environment():
global _WORLD
......