Commit 96a5a949 (Unverified)

[Fix] fix allreduce bug in Piecewise Graph (#12106)

Authored by zyksir on Oct 26, 2025; committed by GitHub on Oct 26, 2025
Parent: ea385ae8
Showing 3 changed files with 12 additions and 14 deletions:

python/sglang/srt/compilation/backend.py (+1, -1)
python/sglang/srt/distributed/parallel_state.py (+0, -7)
python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py (+11, -6)
python/sglang/srt/compilation/backend.py

@@ -392,7 +392,7 @@ class SGLangBackend:
         self.configure_post_pass()

         self.split_gm, self.piecewise_graphs = split_graph(
-            graph, ["sglang.unified_attention_with_output"]
+            graph, ["sglang.unified_attention_with_output", "sglang.inplace_all_reduce"]
         )

         from torch._dynamo.utils import lazy_format_graph_code
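For context on this one-line change: split_graph partitions the traced fx graph at the listed op names, so those ops run outside the piecewise-captured subgraphs; adding "sglang.inplace_all_reduce" keeps all-reduce out of the captured pieces. Below is a minimal, self-contained sketch of op-based splitting built on torch.fx.passes.split_module; the Tiny module and pretend_all_reduce op are hypothetical stand-ins, not sglang's actual split_graph.

import torch
import torch.fx as fx
from torch.fx.passes.split_module import split_module

def pretend_all_reduce(x: torch.Tensor) -> torch.Tensor:
    return x  # hypothetical stand-in for an in-place collective op

class Tiny(torch.nn.Module):
    def forward(self, x):
        x = x + 1
        x = pretend_all_reduce(x)  # split point: runs outside the captured pieces
        return x * 2

gm = fx.symbolic_trace(Tiny())

part = 0
def split_callback(node: fx.Node) -> int:
    # Give every split op its own partition and start a new one after it.
    global part
    if node.op == "call_function" and node.target is pretend_all_reduce:
        part += 1
        idx = part
        part += 1
        return idx
    return part

split = split_module(gm, Tiny(), split_callback)
x = torch.ones(4)
assert torch.equal(split(x), Tiny()(x))  # submod_0..submod_2 compose to the original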
python/sglang/srt/distributed/parallel_state.py

@@ -340,17 +340,10 @@ class GroupCoordinator:
         self.qr_comm: Optional[QuickAllReduce] = None
         if use_custom_allreduce and self.world_size > 1:
             # Initialize a custom fast all-reduce implementation.
-            if torch_compile is not None and torch_compile:
-                # For piecewise CUDA graph, the requirement for custom allreduce is larger to
-                # avoid illegal cuda memory access.
-                ca_max_size = 256 * 1024 * 1024
-            else:
-                ca_max_size = 8 * 1024 * 1024
             try:
                 self.ca_comm = CustomAllreduce(
                     group=self.cpu_group,
                     device=self.device,
-                    max_size=ca_max_size,
                 )
             except Exception as e:
                 logger.warning(
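The deleted block was a workaround: when torch_compile was set, it grew the custom all-reduce buffer from the 8 MiB default to 256 MiB so that all-reduce traffic inside piecewise CUDA graphs would not overrun it and trigger illegal memory accesses. With all-reduce now split out of the captured graphs (backend.py above) and disabled around capture and replay (the runner below), CustomAllreduce can keep its default max_size. A rough sketch of the size gate such a buffer implies; the constant and function here are illustrative, and the real check lives inside the CustomAllreduce implementation:

DEFAULT_CA_MAX_SIZE = 8 * 1024 * 1024  # 8 MiB, the default the code now relies on

def fits_custom_allreduce(num_elements: int, dtype_bytes: int = 2,
                          max_size: int = DEFAULT_CA_MAX_SIZE) -> bool:
    # Messages larger than the pre-registered IPC buffer cannot use the
    # custom kernel and must take the NCCL path instead.
    return num_elements * dtype_bytes <= max_size

assert fits_custom_allreduce(512 * 4096)       # 4 MiB fp16 tensor: custom kernel OK
assert not fits_custom_allreduce(8192 * 4096)  # 64 MiB: falls back to NCCL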
python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py

@@ -32,7 +32,6 @@ from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.distributed.device_communicators.pynccl_allocator import (
     set_graph_pool_id,
 )
-from sglang.srt.distributed.parallel_state import graph_capture
 from sglang.srt.layers.dp_attention import (
     DpPaddingMode,
     get_attention_tp_rank,
@@ -281,10 +280,10 @@ class PiecewiseCudaGraphRunner:
         # Trigger CUDA graph capture for specific shapes.
         # Capture the large shapes first so that the smaller shapes
         # can reuse the memory pool allocated for the large shapes.
         with freeze_gc(
             self.model_runner.server_args.enable_cudagraph_gc
-        ), graph_capture() as graph_capture_context:
-            self.stream = graph_capture_context.stream
+        ):
+            if self.model_runner.tp_group.ca_comm is not None:
+                old_ca_disable = self.model_runner.tp_group.ca_comm.disabled
+                self.model_runner.tp_group.ca_comm.disabled = True
             avail_mem = get_available_gpu_memory(
                 self.model_runner.device,
                 self.model_runner.gpu_id,
@@ -312,9 +311,10 @@ class PiecewiseCudaGraphRunner:
                 # Save gemlite cache after each capture
                 save_gemlite_cache()

+        if self.model_runner.tp_group.ca_comm is not None:
+            self.model_runner.tp_group.ca_comm.disabled = old_ca_disable
+
     def capture_one_batch_size(self, num_tokens: int):
-        stream = self.stream
         bs = 1

         # Graph inputs
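The two hunks above replace the graph_capture() context (whose stream the runner previously captured on) with an open-coded save/flip/restore of ca_comm.disabled spanning the whole capture loop, forcing all-reduce onto the regular collective path during capture. The same idiom written as a context manager, purely as a sketch; sglang keeps it inline, and note that the diff restores the saved old_ca_disable rather than assuming False, so a communicator that was already disabled stays disabled:

from contextlib import contextmanager

@contextmanager
def custom_allreduce_disabled(ca_comm):
    # No-op when custom all-reduce was never initialized.
    if ca_comm is None:
        yield
        return
    old = ca_comm.disabled
    ca_comm.disabled = True  # force the fallback collective while active
    try:
        yield
    finally:
        ca_comm.disabled = old  # restore whatever state we found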
@@ -479,6 +479,9 @@ class PiecewiseCudaGraphRunner:
         forward_batch: ForwardBatch,
         **kwargs,
     ) -> Union[LogitsProcessorOutput, PPProxyTensors]:
+        if self.model_runner.tp_group.ca_comm is not None:
+            old_ca_disable = self.model_runner.tp_group.ca_comm.disabled
+            self.model_runner.tp_group.ca_comm.disabled = True
         static_forward_batch = self.replay_prepare(forward_batch, **kwargs)

         # Replay
         with set_forward_context(static_forward_batch, self.attention_layers):
@@ -504,6 +507,8 @@ class PiecewiseCudaGraphRunner:
             raise NotImplementedError(
                 "PPProxyTensors is not supported in PiecewiseCudaGraphRunner yet."
             )
+        if self.model_runner.tp_group.ca_comm is not None:
+            self.model_runner.tp_group.ca_comm.disabled = old_ca_disable

     def get_spec_info(self, num_tokens: int):
         spec_info = None
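The replay-side hunks mirror the capture-side ones: custom all-reduce is switched off before replay_prepare and restored once the replayed graph has produced its output, so the all-reduce ops that were split out of the captured pieces take the plain collective path during replay as well. A toy illustration of how such a disabled flag typically gates dispatch; ToyCommunicator and this all_reduce wrapper are hypothetical, and the real dispatch lives in GroupCoordinator:

from typing import Optional

import torch

class ToyCommunicator:
    # Stand-in for a custom all-reduce object with a "disabled" switch.
    def __init__(self):
        self.disabled = False

    def custom_all_reduce(self, t: torch.Tensor):
        if self.disabled:
            return None   # decline, so the caller uses the fallback
        return t.clone()  # stand-in for the one-shot CUDA kernel

def all_reduce(t: torch.Tensor, ca_comm: Optional[ToyCommunicator]) -> torch.Tensor:
    out = ca_comm.custom_all_reduce(t) if ca_comm is not None else None
    if out is None:
        out = t.clone()   # stand-in for torch.distributed.all_reduce
    return out

comm = ToyCommunicator()
comm.disabled = True  # what the runner now does around capture and replay
assert torch.equal(all_reduce(torch.ones(2), comm), torch.ones(2))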