Unverified Commit 96a5a949 authored by zyksir, committed by GitHub

[Fix] fix allreduce bug in Piecewise Graph (#12106)

parent ea385ae8
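This change removes the old workaround of enlarging the custom all-reduce buffer under torch.compile (the deleted comment cited illegal CUDA memory accesses with piecewise CUDA graphs). Instead, the compiled graph is now also split at `sglang.inplace_all_reduce`, and `PiecewiseCudaGraphRunner` disables the custom all-reduce communicator (`ca_comm`) while it captures and replays graphs, restoring the previous `disabled` state afterwards. Below is a minimal, hypothetical sketch of that save/disable/restore pattern written as a context manager; only the `disabled` attribute mirrors the diff, while the helper and the stand-in class are illustrative, not sglang code.

```python
# Hypothetical sketch (not sglang code): temporarily disable a custom
# all-reduce communicator and restore its previous state afterwards.
from contextlib import contextmanager


class StubCustomAllreduce:
    """Stand-in for a communicator that exposes a `disabled` flag."""

    def __init__(self) -> None:
        self.disabled = False


@contextmanager
def custom_allreduce_disabled(ca_comm):
    """Disable `ca_comm` for the duration of the block, then restore it."""
    if ca_comm is None:
        yield
        return
    old_disabled = ca_comm.disabled
    ca_comm.disabled = True
    try:
        yield
    finally:
        ca_comm.disabled = old_disabled


if __name__ == "__main__":
    comm = StubCustomAllreduce()
    with custom_allreduce_disabled(comm):
        # capture or replay piecewise CUDA graphs here
        assert comm.disabled
    assert not comm.disabled
```

The patch itself inlines this pattern at the start and end of `capture` and `replay` (save the flag, set it to True, write the old value back) rather than introducing a context manager.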
@@ -392,7 +392,7 @@ class SGLangBackend:
         self.configure_post_pass()
 
         self.split_gm, self.piecewise_graphs = split_graph(
-            graph, ["sglang.unified_attention_with_output"]
+            graph, ["sglang.unified_attention_with_output", "sglang.inplace_all_reduce"]
         )
 
         from torch._dynamo.utils import lazy_format_graph_code
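The hunk above adds `sglang.inplace_all_reduce` to the list of splitting ops, so `split_graph` now partitions the Dynamo-captured FX graph at every in-place all-reduce as well as at attention, keeping those ops at subgraph boundaries. The sketch below illustrates the general idea of splitting an FX graph at named call targets using `torch.fx.passes.split_module.split_module`; the toy model, the choice of `torch.sigmoid` as the split op, and the `split_at_targets` helper are illustrative assumptions, not sglang's actual `split_graph` implementation.

```python
# Illustrative only: partition an FX graph so that selected call targets
# land on their own partition boundaries.
import torch
from torch import fx, nn
from torch.fx.passes.split_module import split_module


class TinyModel(nn.Module):
    def forward(self, x):
        x = torch.relu(x)
        x = torch.sigmoid(x)  # stand-in for a "split op" such as an all-reduce
        return x * 2


def split_at_targets(gm: fx.GraphModule, root: nn.Module, targets):
    """Assign partition ids so each node calling one of `targets` is isolated."""
    part = 0
    node_to_part = {}
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target in targets:
            part += 1
            node_to_part[node] = part  # the split op gets its own partition
            part += 1                  # later nodes start another partition
        else:
            node_to_part[node] = part
    return split_module(gm, root, lambda n: node_to_part[n])


model = TinyModel()
gm = fx.symbolic_trace(model)
split = split_at_targets(gm, model, {torch.sigmoid})
print(split)  # submod_0 (relu) -> submod_1 (sigmoid) -> submod_2 (mul)
x = torch.randn(4)
torch.testing.assert_close(split(x), model(x))
```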
@@ -340,17 +340,10 @@ class GroupCoordinator:
         self.qr_comm: Optional[QuickAllReduce] = None
         if use_custom_allreduce and self.world_size > 1:
             # Initialize a custom fast all-reduce implementation.
-            if torch_compile is not None and torch_compile:
-                # For piecewise CUDA graph, the requirement for custom allreduce is larger to
-                # avoid illegal cuda memory access.
-                ca_max_size = 256 * 1024 * 1024
-            else:
-                ca_max_size = 8 * 1024 * 1024
             try:
                 self.ca_comm = CustomAllreduce(
                     group=self.cpu_group,
                     device=self.device,
-                    max_size=ca_max_size,
                 )
             except Exception as e:
                 logger.warning(
@@ -32,7 +32,6 @@ from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.distributed.device_communicators.pynccl_allocator import (
     set_graph_pool_id,
 )
-from sglang.srt.distributed.parallel_state import graph_capture
 from sglang.srt.layers.dp_attention import (
     DpPaddingMode,
     get_attention_tp_rank,
@@ -281,10 +280,10 @@ class PiecewiseCudaGraphRunner:
         # Trigger CUDA graph capture for specific shapes.
         # Capture the large shapes first so that the smaller shapes
         # can reuse the memory pool allocated for the large shapes.
-        with freeze_gc(
-            self.model_runner.server_args.enable_cudagraph_gc
-        ), graph_capture() as graph_capture_context:
-            self.stream = graph_capture_context.stream
+        with freeze_gc(self.model_runner.server_args.enable_cudagraph_gc):
+            if self.model_runner.tp_group.ca_comm is not None:
+                old_ca_disable = self.model_runner.tp_group.ca_comm.disabled
+                self.model_runner.tp_group.ca_comm.disabled = True
             avail_mem = get_available_gpu_memory(
                 self.model_runner.device,
                 self.model_runner.gpu_id,
@@ -312,9 +311,10 @@ class PiecewiseCudaGraphRunner:
                     # Save gemlite cache after each capture
                     save_gemlite_cache()
 
+            if self.model_runner.tp_group.ca_comm is not None:
+                self.model_runner.tp_group.ca_comm.disabled = old_ca_disable
+
     def capture_one_batch_size(self, num_tokens: int):
         stream = self.stream
         bs = 1
 
         # Graph inputs
@@ -479,6 +479,9 @@ class PiecewiseCudaGraphRunner:
         forward_batch: ForwardBatch,
         **kwargs,
     ) -> Union[LogitsProcessorOutput, PPProxyTensors]:
+        if self.model_runner.tp_group.ca_comm is not None:
+            old_ca_disable = self.model_runner.tp_group.ca_comm.disabled
+            self.model_runner.tp_group.ca_comm.disabled = True
         static_forward_batch = self.replay_prepare(forward_batch, **kwargs)
         # Replay
         with set_forward_context(static_forward_batch, self.attention_layers):
@@ -504,6 +507,8 @@
             raise NotImplementedError(
                 "PPProxyTensors is not supported in PiecewiseCudaGraphRunner yet."
             )
+        if self.model_runner.tp_group.ca_comm is not None:
+            self.model_runner.tp_group.ca_comm.disabled = old_ca_disable
 
     def get_spec_info(self, num_tokens: int):
         spec_info = None
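For background on why toggling `ca_comm.disabled` is sufficient: a group coordinator typically tries the custom all-reduce kernel first and falls back to the regular collective when the communicator is absent or disabled, so disabling it during capture and replay routes all-reduces through the fallback path instead of the custom kernel. The sketch below is a hypothetical illustration of that dispatch pattern; `ToyGroupCoordinator` and its methods are not sglang's implementation, and exercising the fallback requires an initialized `torch.distributed` process group.

```python
# Hypothetical dispatch sketch (not sglang code): prefer the custom
# all-reduce kernel, fall back to the default collective when it is
# missing or disabled.
import torch
import torch.distributed as dist


class ToyGroupCoordinator:
    def __init__(self, ca_comm=None):
        # `ca_comm` is any object with a `disabled` flag and an
        # `all_reduce(tensor)` method that may return None when it
        # cannot handle the input.
        self.ca_comm = ca_comm

    def all_reduce(self, tensor: torch.Tensor) -> torch.Tensor:
        ca = self.ca_comm
        if ca is not None and not ca.disabled:
            out = ca.all_reduce(tensor)  # fast path: custom kernel
            if out is not None:
                return out
        dist.all_reduce(tensor)  # fallback path (needs an initialized group)
        return tensor
```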