Unverified commit c4e81e64, authored by ykcombat, committed by GitHub
Browse files

[Feature] Use current greenctx stream to communicate in PD-Multiplexing. (#11594)

parent c726d44c
...@@ -30,6 +30,7 @@ class PyNcclCommunicator: ...@@ -30,6 +30,7 @@ class PyNcclCommunicator:
group: Union[ProcessGroup, StatelessProcessGroup], group: Union[ProcessGroup, StatelessProcessGroup],
device: Union[int, str, torch.device], device: Union[int, str, torch.device],
library_path: Optional[str] = None, library_path: Optional[str] = None,
use_current_stream: bool = False,
): ):
""" """
Args: Args:
...@@ -74,6 +75,7 @@ class PyNcclCommunicator: ...@@ -74,6 +75,7 @@ class PyNcclCommunicator:
self.available = True self.available = True
self.disabled = False self.disabled = False
self.use_current_stream = use_current_stream
self.nccl_version = self.nccl.ncclGetRawVersion() self.nccl_version = self.nccl.ncclGetRawVersion()
if self.rank == 0: if self.rank == 0:
...@@ -123,6 +125,21 @@ class PyNcclCommunicator: ...@@ -123,6 +125,21 @@ class PyNcclCommunicator:
# when we are using CUDA graph. # when we are using CUDA graph.
self.disabled = True self.disabled = True
def _resolve_stream(self, stream: Optional[torch.cuda.Stream]):
"""Return the stream to use for NCCL calls.
Behavior mirrors the previous inline logic:
- if an explicit stream is provided, return it
- if stream is None and self.use_current_stream is True, return
torch.cuda.current_stream()
- otherwise return the communicator's default stream (self.stream)
"""
if stream is not None:
return stream
if self.use_current_stream:
return torch.cuda.current_stream()
return self.stream
def all_reduce( def all_reduce(
self, tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, stream=None self, tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, stream=None
): ):
...@@ -135,8 +152,7 @@ class PyNcclCommunicator: ...@@ -135,8 +152,7 @@ class PyNcclCommunicator:
f"this nccl communicator is created to work on {self.device}, " f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}" f"but the input tensor is on {tensor.device}"
) )
if stream is None: stream = self._resolve_stream(stream)
stream = self.stream
self.nccl.ncclAllReduce( self.nccl.ncclAllReduce(
buffer_type(tensor.data_ptr()), buffer_type(tensor.data_ptr()),
buffer_type(tensor.data_ptr()), buffer_type(tensor.data_ptr()),
...@@ -163,8 +179,7 @@ class PyNcclCommunicator: ...@@ -163,8 +179,7 @@ class PyNcclCommunicator:
f"this nccl communicator is created to work on {self.device}, " f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {input_tensor.device}" f"but the input tensor is on {input_tensor.device}"
) )
if stream is None: stream = self._resolve_stream(stream)
stream = self.stream
if sizes is not None: if sizes is not None:
split_offset = 0 split_offset = 0
...@@ -210,8 +225,7 @@ class PyNcclCommunicator: ...@@ -210,8 +225,7 @@ class PyNcclCommunicator:
f"this nccl communicator is created to work on {self.device}, " f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {input_tensor.device}" f"but the input tensor is on {input_tensor.device}"
) )
if stream is None: stream = self._resolve_stream(stream)
stream = self.stream
if sizes is not None: if sizes is not None:
split_offset = 0 split_offset = 0
...@@ -249,8 +263,7 @@ class PyNcclCommunicator: ...@@ -249,8 +263,7 @@ class PyNcclCommunicator:
f"this nccl communicator is created to work on {self.device}, " f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}" f"but the input tensor is on {tensor.device}"
) )
if stream is None: stream = self._resolve_stream(stream)
stream = self.stream
self.nccl.ncclSend( self.nccl.ncclSend(
buffer_type(tensor.data_ptr()), buffer_type(tensor.data_ptr()),
tensor.numel(), tensor.numel(),
...@@ -267,8 +280,7 @@ class PyNcclCommunicator: ...@@ -267,8 +280,7 @@ class PyNcclCommunicator:
f"this nccl communicator is created to work on {self.device}, " f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}" f"but the input tensor is on {tensor.device}"
) )
if stream is None: stream = self._resolve_stream(stream)
stream = self.stream
self.nccl.ncclRecv( self.nccl.ncclRecv(
buffer_type(tensor.data_ptr()), buffer_type(tensor.data_ptr()),
tensor.numel(), tensor.numel(),
...@@ -285,8 +297,8 @@ class PyNcclCommunicator: ...@@ -285,8 +297,8 @@ class PyNcclCommunicator:
f"this nccl communicator is created to work on {self.device}, " f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}" f"but the input tensor is on {tensor.device}"
) )
if stream is None: stream = self._resolve_stream(stream)
stream = self.stream
if src == self.rank: if src == self.rank:
sendbuff = buffer_type(tensor.data_ptr()) sendbuff = buffer_type(tensor.data_ptr())
# NCCL requires the sender also to have a receive buffer # NCCL requires the sender also to have a receive buffer
......
...@@ -239,6 +239,7 @@ class GroupCoordinator: ...@@ -239,6 +239,7 @@ class GroupCoordinator:
use_npu_communicator: bool, use_npu_communicator: bool,
use_message_queue_broadcaster: bool = False, use_message_queue_broadcaster: bool = False,
group_name: Optional[str] = None, group_name: Optional[str] = None,
pynccl_use_current_stream: bool = False,
torch_compile: Optional[bool] = None, torch_compile: Optional[bool] = None,
gloo_timeout: timedelta = timedelta(seconds=120 * 60), gloo_timeout: timedelta = timedelta(seconds=120 * 60),
): ):
...@@ -289,6 +290,7 @@ class GroupCoordinator: ...@@ -289,6 +290,7 @@ class GroupCoordinator:
# Import communicators # Import communicators
self.use_pynccl = use_pynccl self.use_pynccl = use_pynccl
self.pynccl_use_current_stream = pynccl_use_current_stream
self.use_pymscclpp = use_pymscclpp self.use_pymscclpp = use_pymscclpp
self.use_custom_allreduce = use_custom_allreduce self.use_custom_allreduce = use_custom_allreduce
self.use_torch_symm_mem = use_torch_symm_mem self.use_torch_symm_mem = use_torch_symm_mem
...@@ -322,6 +324,7 @@ class GroupCoordinator: ...@@ -322,6 +324,7 @@ class GroupCoordinator:
self.pynccl_comm = PyNcclCommunicator( self.pynccl_comm = PyNcclCommunicator(
group=self.cpu_group, group=self.cpu_group,
device=self.device, device=self.device,
use_current_stream=pynccl_use_current_stream,
) )
self.pymscclpp_comm: Optional[PyMscclppCommunicator] = None self.pymscclpp_comm: Optional[PyMscclppCommunicator] = None
...@@ -449,10 +452,13 @@ class GroupCoordinator: ...@@ -449,10 +452,13 @@ class GroupCoordinator:
@contextmanager @contextmanager
def graph_capture( def graph_capture(
self, graph_capture_context: Optional[GraphCaptureContext] = None self,
graph_capture_context: Optional[GraphCaptureContext] = None,
stream: Optional[torch.cuda.Stream] = None,
): ):
if graph_capture_context is None: if graph_capture_context is None:
stream = self.device_module.Stream() if stream is None:
stream = self.device_module.Stream()
graph_capture_context = GraphCaptureContext(stream) graph_capture_context = GraphCaptureContext(stream)
else: else:
stream = graph_capture_context.stream stream = graph_capture_context.stream
...@@ -1278,6 +1284,7 @@ def init_model_parallel_group( ...@@ -1278,6 +1284,7 @@ def init_model_parallel_group(
use_message_queue_broadcaster: bool = False, use_message_queue_broadcaster: bool = False,
group_name: Optional[str] = None, group_name: Optional[str] = None,
use_mscclpp_allreduce: Optional[bool] = None, use_mscclpp_allreduce: Optional[bool] = None,
pynccl_use_current_stream: bool = True,
use_symm_mem_allreduce: Optional[bool] = None, use_symm_mem_allreduce: Optional[bool] = None,
torch_compile: Optional[bool] = None, torch_compile: Optional[bool] = None,
) -> GroupCoordinator: ) -> GroupCoordinator:
...@@ -1300,6 +1307,7 @@ def init_model_parallel_group( ...@@ -1300,6 +1307,7 @@ def init_model_parallel_group(
use_npu_communicator=True, use_npu_communicator=True,
use_message_queue_broadcaster=use_message_queue_broadcaster, use_message_queue_broadcaster=use_message_queue_broadcaster,
group_name=group_name, group_name=group_name,
pynccl_use_current_stream=pynccl_use_current_stream,
torch_compile=torch_compile, torch_compile=torch_compile,
) )
...@@ -1357,7 +1365,7 @@ get_pipeline_model_parallel_group = get_pp_group ...@@ -1357,7 +1365,7 @@ get_pipeline_model_parallel_group = get_pp_group
@contextmanager @contextmanager
def graph_capture(): def graph_capture(stream: Optional[torch.cuda.Stream] = None):
""" """
`graph_capture` is a context manager which should surround the code that `graph_capture` is a context manager which should surround the code that
is capturing the CUDA graph. Its main purpose is to ensure that the is capturing the CUDA graph. Its main purpose is to ensure that the
...@@ -1371,9 +1379,9 @@ def graph_capture(): ...@@ -1371,9 +1379,9 @@ def graph_capture():
in order to explicitly distinguish the kernels to capture in order to explicitly distinguish the kernels to capture
from other kernels possibly launched on background in the default stream. from other kernels possibly launched on background in the default stream.
""" """
with get_tp_group().graph_capture() as context, get_pp_group().graph_capture( with get_tp_group().graph_capture(
context stream=stream
): ) as context, get_pp_group().graph_capture(context):
yield context yield context
...@@ -1527,6 +1535,7 @@ def initialize_model_parallel( ...@@ -1527,6 +1535,7 @@ def initialize_model_parallel(
"SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true" "SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true"
), ),
group_name="tp", group_name="tp",
pynccl_use_current_stream=duplicate_tp_group,
torch_compile=torch_compile, torch_compile=torch_compile,
) )
...@@ -1543,10 +1552,12 @@ def initialize_model_parallel( ...@@ -1543,10 +1552,12 @@ def initialize_model_parallel(
"SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true" "SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true"
), ),
group_name="pdmux_prefill_tp", group_name="pdmux_prefill_tp",
pynccl_use_current_stream=True,
torch_compile=torch_compile, torch_compile=torch_compile,
) )
_TP.pynccl_comm.disabled = False if _TP.pynccl_comm:
_PDMUX_PREFILL_TP_GROUP.pynccl_comm.disabled = False _TP.pynccl_comm.disabled = False
_PDMUX_PREFILL_TP_GROUP.pynccl_comm.disabled = False
moe_ep_size = expert_model_parallel_size moe_ep_size = expert_model_parallel_size
moe_tp_size = tensor_model_parallel_size // moe_ep_size moe_tp_size = tensor_model_parallel_size // moe_ep_size
...@@ -1737,6 +1748,11 @@ def destroy_model_parallel(): ...@@ -1737,6 +1748,11 @@ def destroy_model_parallel():
_PP.destroy() _PP.destroy()
_PP = None _PP = None
global _PDMUX_PREFILL_TP_GROUP
if _PDMUX_PREFILL_TP_GROUP: # type: ignore[union-attr]
_PDMUX_PREFILL_TP_GROUP.destroy()
_PDMUX_PREFILL_TP_GROUP = None
def destroy_distributed_environment(): def destroy_distributed_environment():
global _WORLD global _WORLD
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment